Download Install Tutorial Docs FAQ Tools Wiki License Team IRC Planet Involvement Shop Book

root/trunk/cherrypy/test/webtest.py

Revision 1901 (checked in by fumanchu, 6 months ago)

I suppose defaults don't mean much for property setters. ;)

  • Property svn:eol-style set to native
Line 
1 """Extensions to unittest for web frameworks.
2
3 Use the WebCase.getPage method to request a page from your HTTP server.
4
5 Framework Integration
6 =====================
7
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
12
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
17 """
18
19 import httplib
20 import os
21 import pprint
22 import re
23 import socket
24 import sys
25 import time
26 import traceback
27 import types
28
29 from unittest import *
30 from unittest import _TextTestResult
31
32
class TerseTestResult(_TextTestResult):
    """Text test result whose error report stays silent on a clean run."""

    def printErrors(self):
        """Print the ERROR and FAIL listings, if there are any.

        Overridden so that no empty line is written to the stream when
        there are neither errors nor failures to report.
        """
        if not (self.errors or self.failures):
            return
        if self.showAll or self.dots:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors), ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
42
43
class TerseTestRunner(TextTestRunner):
    """Text test runner with compact output.

    Prints only the error listings and, when the run did not succeed, a
    one-line "FAILED (...)" summary.
    """

    def _makeResult(self):
        # TerseTestResult skips the blank line when nothing failed.
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        "Run the given test case or test suite."
        # Overridden to remove unnecessary empty lines and separators.
        result = self._makeResult()
        test(result)
        result.printErrors()
        if result.wasSuccessful():
            return result

        counts = []
        if result.failures:
            counts.append("failures=%d" % len(result.failures))
        if result.errors:
            counts.append("errors=%d" % len(result.errors))
        self.stream.write("FAILED (")
        self.stream.write(", ".join(counts))
        self.stream.writeln(")")
        return result
66
67
class ReloadingTestLoader(TestLoader):
    """TestLoader that reloads already-imported modules before loading tests."""

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.
        Without a module, the longest dotted prefix of 'name' that is either
        already in sys.modules (in which case it is reloaded) or importable
        is loaded, and the remaining parts are looked up as attributes
        starting from the top-level package.

        Raises ValueError for an empty name or an object that is not a test;
        ImportError if no prefix of 'name' can be imported.
        """
        parts = name.split('.')
        if module is None:
            # str.split never returns an empty list, so test the name itself.
            if not name:
                raise ValueError("incomplete test name: %s" % name)
            parts_copy = parts[:]
            while parts_copy:
                target = ".".join(parts_copy)
                if target in sys.modules:
                    reload(sys.modules[target])
                    # sys.modules[target] is the *submodule*; the attribute
                    # walk below starts at the top-level package (parts[0]
                    # is dropped), which is what __import__ returns. The
                    # module is already in sys.modules at this point.
                    module = __import__(target)
                    break
                try:
                    module = __import__(target)
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returned the package named by parts[0].
            parts = parts[1:]

        obj = module
        for part in parts:
            obj = getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif (isinstance(obj, (type, types.ClassType)) and
              issubclass(obj, TestCase)):
            return self.loadTestsFromTestCase(obj)
        elif type(obj) == types.UnboundMethodType:
            return obj.im_class(obj.__name__)
        elif callable(obj):
            test = obj()
            if not isinstance(test, (TestCase, TestSuite)):
                raise ValueError("calling %s returned %s, "
                                 "not a test" % (obj, test))
            return test
        else:
            raise ValueError("do not know how to make test from: %s" % obj)
119
120
try:
    # Windows: msvcrt.getch reads a single keypress without echoing it.
    import msvcrt
except ImportError:
    # Unix: put the terminal in raw mode just long enough to read one char.
    import tty, termios

    def getchar():
        """Read one character from stdin without waiting for Enter."""
        stdin_fd = sys.stdin.fileno()
        saved_attrs = termios.tcgetattr(stdin_fd)
        try:
            tty.setraw(stdin_fd)
            return sys.stdin.read(1)
        finally:
            # Always restore the terminal, even if the read fails.
            termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
else:
    def getchar():
        """Read one keypress from the console without echoing it."""
        return msvcrt.getch()
138
139
140 class WebCase(TestCase):
141     HOST = "127.0.0.1"
142     PORT = 8000
143     HTTP_CONN = httplib.HTTPConnection
144     PROTOCOL = "HTTP/1.1"
145    
146     scheme = "http"
147     url = None
148    
149     status = None
150     headers = None
151     body = None
152     time = None
153    
154     def set_persistent(self, on=True, auto_open=False):
155         """Make our HTTP_CONN persistent (or not).
156         
157         If the 'on' argument is True (the default), then self.HTTP_CONN
158         will be set to an instance of httplib.HTTPConnection (or HTTPS
159         if self.scheme is "https"). This will then persist across requests.
160         
161         We only allow for a single open connection, so if you call this
162         and we currently have an open connection, it will be closed.
163         """
164         try:
165             self.HTTP_CONN.close()
166         except (TypeError, AttributeError):
167             pass
168        
169         if self.scheme == "https":
170             cls = httplib.HTTPSConnection
171         else:
172             cls = httplib.HTTPConnection
173        
174         if on:
175             host = self.HOST
176             if host == '0.0.0.0':
177                 # INADDR_ANY, which should respond on localhost.
178                 host = "127.0.0.1"
179             elif host == '::':
180                 # IN6ADDR_ANY, which should respond on localhost.
181                 host = "::1"
182             self.HTTP_CONN = cls(host, self.PORT)
183             # Automatically re-connect?
184             self.HTTP_CONN.auto_open = auto_open
185             self.HTTP_CONN.connect()
186         else:
187             self.HTTP_CONN = cls
188    
189     def _get_persistent(self):
190         return hasattr(self.HTTP_CONN, "__class__")
191     def _set_persistent(self, on):
192         self.set_persistent(on)
193     persistent = property(_get_persistent, _set_persistent)
194    
195     def getPage(self, url, headers=None, method="GET", body=None, protocol=None):
196         """Open the url with debugging support. Return status, headers, body."""
197         ServerError.on = False
198        
199         self.url = url
200         self.time = None
201         start = time.time()
202         result = openURL(url, headers, method, body, self.HOST, self.PORT,
203                          self.HTTP_CONN, protocol or self.PROTOCOL)
204         self.time = time.time() - start
205         self.status, self.headers, self.body = result
206        
207         # Build a list of request cookies from the previous response cookies.
208         self.cookies = [('Cookie', v) for k, v in self.headers
209                         if k.lower() == 'set-cookie']
210        
211         if ServerError.on:
212             raise ServerError()
213         return result
214    
215     interactive = True
216     console_height = 30
217    
218     def _handlewebError(self, msg):
219         print
220         print "    ERROR:", msg
221        
222         if not self.interactive:
223             raise self.failureException(msg)
224        
225         p = "    Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
226         print p,
227         while True:
228             i = getchar().upper()
229             if i not in "BHSUIRX":
230                 continue
231             print i.upper()  # Also prints new line
232             if i == "B":
233                 for x, line in enumerate(self.body.splitlines()):
234                     if (x + 1) % self.console_height == 0:
235                         # The \r and comma should make the next line overwrite
236                         print "<-- More -->\r",
237                         m = getchar().lower()
238                         # Erase our "More" prompt
239                         print "            \r",
240                         if m == "q":
241                             break
242                     print line
243             elif i == "H":
244                 pprint.pprint(self.headers)
245             elif i == "S":
246                 print self.status
247             elif i == "U":
248                 print self.url
249             elif i == "I":
250                 # return without raising the normal exception
251                 return
252             elif i == "R":
253                 raise self.failureException(msg)
254             elif i == "X":
255                 self.exit()
256             print p,
257    
258     def exit(self):
259         sys.exit()
260    
261     if sys.version_info >= (2, 5):
262         def __call__(self, result=None):
263             if result is None:
264                 result = self.defaultTestResult()
265             result.startTest(self)
266             testMethod = getattr(self, self._testMethodName)
267             try:
268                 try:
269                     self.setUp()
270                 except (KeyboardInterrupt, SystemExit):
271                     raise
272                 except:
273                     result.addError(self, self._exc_info())
274                     return
275                
276                 ok = 0
277                 try:
278                     testMethod()
279                     ok = 1
280                 except self.failureException:
281                     result.addFailure(self, self._exc_info())
282                 except (KeyboardInterrupt, SystemExit):
283                     raise
284                 except:
285                     result.addError(self, self._exc_info())
286                
287                 try:
288                     self.tearDown()
289                 except (KeyboardInterrupt, SystemExit):
290                     raise
291                 except:
292                     result.addError(self, self._exc_info())
293                     ok = 0
294                 if ok:
295                     result.addSuccess(self)
296             finally:
297                 result.stopTest(self)
298     else:
299         def __call__(self, result=None):
300             if result is None:
301                 result = self.defaultTestResult()
302             result.startTest(self)
303             testMethod = getattr(self, self._TestCase__testMethodName)
304             try:
305                 try:
306                     self.setUp()
307                 except (KeyboardInterrupt, SystemExit):
308                     raise
309                 except:
310                     result.addError(self, self._TestCase__exc_info())
311                     return
312                
313                 ok = 0
314                 try:
315                     testMethod()
316                     ok = 1
317                 except self.failureException:
318                     result.addFailure(self, self._TestCase__exc_info())
319                 except (KeyboardInterrupt, SystemExit):
320                     raise
321                 except:
322                     result.addError(self, self._TestCase__exc_info())
323                
324                 try:
325                     self.tearDown()
326                 except (KeyboardInterrupt, SystemExit):
327                     raise
328                 except:
329                     result.addError(self, self._TestCase__exc_info())
330                     ok = 0
331                 if ok:
332                     result.addSuccess(self)
333             finally:
334                 result.stopTest(self)
335    
336     def assertStatus(self, status, msg=None):
337         """Fail if self.status != status."""
338         if isinstance(status, basestring):
339             if not self.status == status:
340                 if msg is None:
341                     msg = 'Status (%r) != %r' % (self.status, status)
342                 self._handlewebError(msg)
343         elif isinstance(status, int):
344             code = int(self.status[:3])
345             if code != status:
346                 if msg is None:
347                     msg = 'Status (%r) != %r' % (self.status, status)
348                 self._handlewebError(msg)
349         else:
350             # status is a tuple or list.
351             match = False
352             for s in status:
353                 if isinstance(s, basestring):
354                     if self.status == s:
355                         match = True
356                         break
357                 elif int(self.status[:3]) == s:
358                     match = True
359                     break
360             if not match:
361                 if msg is None:
362                     msg = 'Status (%r) not in %r' % (self.status, status)
363                 self._handlewebError(msg)
364    
365     def assertHeader(self, key, value=None, msg=None):
366         """Fail if (key, [value]) not in self.headers."""
367         lowkey = key.lower()
368         for k, v in self.headers:
369             if k.lower() == lowkey:
370                 if value is None or str(value) == v:
371                     return v
372        
373         if msg is None:
374             if value is None:
375                 msg = '%r not in headers' % key
376             else:
377                 msg = '%r:%r not in headers' % (key, value)
378         self._handlewebError(msg)
379    
380     def assertNoHeader(self, key, msg=None):
381         """Fail if key in self.headers."""
382         lowkey = key.lower()
383         matches = [k for k, v in self.headers if k.lower() == lowkey]
384         if matches:
385             if msg is None:
386                 msg = '%r in headers' % key
387             self._handlewebError(msg)
388    
389     def assertBody(self, value, msg=None):
390         """Fail if value != self.body."""
391         if value != self.body:
392             if msg is None:
393                 msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body)
394             self._handlewebError(msg)
395    
396     def assertInBody(self, value, msg=None):
397         """Fail if value not in self.body."""
398         if value not in self.body:
399             if msg is None:
400                 msg = '%r not in body' % value
401             self._handlewebError(msg)
402    
403     def assertNotInBody(self, value, msg=None):
404         """Fail if value in self.body."""
405         if value in self.body:
406             if msg is None:
407                 msg = '%r found in body' % value
408             self._handlewebError(msg)
409    
410     def assertMatchesBody(self, pattern, msg=None, flags=0):
411         """Fail if value (a regex pattern) is not in self.body."""
412         if re.search(pattern, self.body, flags) is None:
413             if msg is None:
414                 msg = 'No match for %r in body' % pattern
415             self._handlewebError(msg)
416
417
methods_with_bodies = ("POST", "PUT")


def cleanHeaders(headers, method, body, host, port):
    """Return request headers, with required headers added (if missing).

    headers: a sequence of (name, value) pairs, or None. It is copied, never
        modified, so callers can safely reuse the same list across requests.
    method: HTTP method; for POST/PUT without a Content-Type, a
        form-urlencoded Content-Type and a matching Content-Length are added.
    body: request body (may be None), used for the default Content-Length.
    host, port: the server address, used for a missing Host header. IPv6
        literals are bracketed, and the port is omitted when it is 80.
    """
    if headers is None:
        headers = []
    else:
        # Appending to the caller's list would leak headers (notably a stale
        # Content-Length) into later requests that reuse it.
        headers = list(headers)
    present = [k.lower() for k, v in headers]

    # Add the required Host request header if not present.
    # [This specifies the host:port of the server, not the client.]
    if 'host' not in present:
        if ':' in host and not host.startswith('['):
            # An IPv6 literal must be bracketed in a Host header.
            host_value = '[%s]' % host
        else:
            host_value = host
        if port != 80:
            host_value = "%s:%s" % (host_value, port)
        headers.append(("Host", host_value))

    if method in methods_with_bodies and 'content-type' not in present:
        # Stick in default type and length headers if not present
        headers.append(("Content-Type", "application/x-www-form-urlencoded"))
        headers.append(("Content-Length", str(len(body or ""))))

    return headers
450
451
452 def shb(response):
453     """Return status, headers, body the way we like from a response."""
454     h = []
455     key, value = None, None
456     for line in response.msg.headers:
457         if line:
458             if line[0] in " \t":
459                 value += line.strip()
460             else:
461                 if key and value:
462                     h.append((key, value))
463                 key, value = line.split(":", 1)
464                 key = key.strip()
465                 value = value.strip()
466     if key and value:
467         h.append((key, value))
468    
469     return "%s %s" % (response.status, response.reason), h, response.read()
470
471
472 def openURL(url, headers=None, method="GET", body=None,
473             host="127.0.0.1", port=8000, http_conn=httplib.HTTPConnection,
474             protocol="HTTP/1.1"):
475     """Open the given HTTP resource and return status, headers, and body."""
476    
477     headers = cleanHeaders(headers, method, body, host, port)
478    
479     # Trying 10 times is simply in case of socket errors.
480</