Download Install Tutorial Docs FAQ Tools Wiki License Team IRC Planet Involvement Shop Book

root/trunk/cherrypy/test/webtest.py

Revision 2076 (checked in by fumanchu, 3 weeks ago)

Whew. Fixed the whole test suite to properly handle the --host arg.

  • Property svn:eol-style set to native
Line 
1 """Extensions to unittest for web frameworks.
2
3 Use the WebCase.getPage method to request a page from your HTTP server.
4
5 Framework Integration
6 =====================
7
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
12
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
17 """
18
19 import httplib
20 import os
21 import pprint
22 import re
23 import socket
24 import sys
25 import time
26 import traceback
27 import types
28
29 from unittest import *
30 from unittest import _TextTestResult
31
32
class TerseTestResult(_TextTestResult):
    """Text test result that prints nothing when there is nothing to report."""

    def printErrors(self):
        """Print the ERROR and FAIL listings, if there are any.

        The separating blank line is only written when at least one error
        or failure listing will follow it.
        """
        nothing_to_report = not (self.errors or self.failures)
        if nothing_to_report:
            return
        # Separate the listings from the progress dots / verbose lines.
        if self.showAll or self.dots:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors),
                                  ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
42
43
class TerseTestRunner(TextTestRunner):
    """A TextTestRunner variant with minimal output.

    run() prints only the error/failure listings and, when the run was not
    successful, a single "FAILED (...)" summary line.
    """

    def _makeResult(self):
        """Create the TerseTestResult that run() reports into."""
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        """Run the given test case or test suite and return its result."""
        outcome = self._makeResult()
        test(outcome)
        outcome.printErrors()
        if outcome.wasSuccessful():
            return outcome

        counts = []
        n_failed = len(outcome.failures)
        n_errored = len(outcome.errors)
        if n_failed:
            counts.append("failures=%d" % n_failed)
        if n_errored:
            counts.append("errors=%d" % n_errored)
        self.stream.writeln("FAILED (%s)" % ", ".join(counts))
        return outcome
66
67
68 class ReloadingTestLoader(TestLoader):
69    
70     def loadTestsFromName(self, name, module=None):
71         """Return a suite of all tests cases given a string specifier.
72
73         The name may resolve either to a module, a test case class, a
74         test method within a test case class, or a callable object which
75         returns a TestCase or TestSuite instance.
76
77         The method optionally resolves the names relative to a given module.
78         """
79         parts = name.split('.')
80         if module is None:
81             if not parts:
82                 raise ValueError("incomplete test name: %s" % name)
83             else:
84                 parts_copy = parts[:]
85                 while parts_copy:
86                     target = ".".join(parts_copy)
87                     if target in sys.modules:
88                         module = reload(sys.modules[target])
89                         break
90                     else:
91                         try:
92                             module = __import__(target)
93                             break
94                         except ImportError:
95                             del parts_copy[-1]
96                             if not parts_copy:
97                                 raise
98                 parts = parts[1:]
99         obj = module
100         for part in parts:
101             obj = getattr(obj, part)
102        
103         if type(obj) == types.ModuleType:
104             return self.loadTestsFromModule(obj)
105         elif (isinstance(obj, (type, types.ClassType)) and
106               issubclass(obj, TestCase)):
107             return self.loadTestsFromTestCase(obj)
108         elif type(obj) == types.UnboundMethodType:
109             return obj.im_class(obj.__name__)
110         elif callable(obj):
111             test = obj()
112             if not isinstance(test, TestCase) and \
113                not isinstance(test, TestSuite):
114                 raise ValueError("calling %s returned %s, "
115                                  "not a test" % (obj,test))
116             return test
117         else:
118             raise ValueError("do not know how to make test from: %s" % obj)
119
120
try:
    # Windows: msvcrt.getch() reads a single keypress without echoing it.
    import msvcrt

    def getchar():
        """Return one keypress from the console (Windows, via msvcrt)."""
        return msvcrt.getch()
except ImportError:
    # Elsewhere: put the terminal in raw mode just long enough to read
    # a single character, then restore its previous settings.
    import tty, termios

    def getchar():
        """Return one character read from stdin with the tty in raw mode.

        The original terminal attributes are restored even if the read
        raises.
        """
        stdin_fd = sys.stdin.fileno()
        saved_attrs = termios.tcgetattr(stdin_fd)
        try:
            tty.setraw(stdin_fd)
            char = sys.stdin.read(1)
        finally:
            termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
        return char
138
139
140 class WebCase(TestCase):
141     HOST = "127.0.0.1"
142     PORT = 8000
143     HTTP_CONN = httplib.HTTPConnection
144     PROTOCOL = "HTTP/1.1"
145    
146     scheme = "http"
147     url = None
148    
149     status = None
150     headers = None
151     body = None
152     time = None
153    
154     def set_persistent(self, on=True, auto_open=False):
155         """Make our HTTP_CONN persistent (or not).
156         
157         If the 'on' argument is True (the default), then self.HTTP_CONN
158         will be set to an instance of httplib.HTTPConnection (or HTTPS
159         if self.scheme is "https"). This will then persist across requests.
160         
161         We only allow for a single open connection, so if you call this
162         and we currently have an open connection, it will be closed.
163         """
164         try:
165             self.HTTP_CONN.close()
166         except (TypeError, AttributeError):
167             pass
168        
169         if self.scheme == "https":
170             cls = httplib.HTTPSConnection
171         else:
172             cls = httplib.HTTPConnection
173        
174         if on:
175             host = self.HOST
176             if host == '0.0.0.0':
177                 # INADDR_ANY, which should respond on localhost.
178                 host = "127.0.0.1"
179             elif host == '::':
180                 # IN6ADDR_ANY, which should respond on localhost.
181                 host = "::1"
182             self.HTTP_CONN = cls(host, self.PORT)
183             # Automatically re-connect?
184             self.HTTP_CONN.auto_open = auto_open
185             self.HTTP_CONN.connect()
186         else:
187             self.HTTP_CONN = cls
188    
189     def _get_persistent(self):
190         return hasattr(self.HTTP_CONN, "__class__")
191     def _set_persistent(self, on):
192         self.set_persistent(on)
193     persistent = property(_get_persistent, _set_persistent)
194    
195     def interface(self):
196         """Return an IP address for a client connection.
197         
198         If the server is listening on '0.0.0.0' (INADDR_ANY)
199         or '::' (IN6ADDR_ANY), this will return the proper localhost."""
200         host = self.HOST
201         if host == '0.0.0.0':
202             # INADDR_ANY, which should respond on localhost.
203             return "127.0.0.1"
204         if host == '::':
205             # IN6ADDR_ANY, which should respond on localhost.
206             return "::1"
207         return host
208    
209     def getPage(self, url, headers=None, method="GET", body=None, protocol=None):
210         """Open the url with debugging support. Return status, headers, body."""
211         ServerError.on = False
212        
213         self.url = url
214         self.time = None
215         start = time.time()
216         result = openURL(url, headers, method, body, self.HOST, self.PORT,
217                          self.HTTP_CONN, protocol or self.PROTOCOL)
218         self.time = time.time() - start
219         self.status, self.headers, self.body = result
220        
221         # Build a list of request cookies from the previous response cookies.
222         self.cookies = [('Cookie', v) for k, v in self.headers
223                         if k.lower() == 'set-cookie']
224        
225         if ServerError.on:
226             raise ServerError()
227         return result
228    
229     interactive = True
230     console_height = 30
231    
232     def _handlewebError(self, msg):
233         print
234         print "    ERROR:", msg
235        
236         if not self.interactive:
237             raise self.failureException(msg)
238        
239         p = "    Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
240         print p,
241         while True:
242             i = getchar().upper()
243             if i not in "BHSUIRX":
244                 continue
245             print i.upper()  # Also prints new line
246             if i == "B":
247                 for x, line in enumerate(self.body.splitlines()):
248                     if (x + 1) % self.console_height == 0:
249                         # The \r and comma should make the next line overwrite
250                         print "<-- More -->\r",
251                         m = getchar().lower()
252                         # Erase our "More" prompt
253                         print "            \r",
254                         if m == "q":
255                             break
256                     print line
257             elif i == "H":
258                 pprint.pprint(self.headers)
259             elif i == "S":
260                 print self.status
261             elif i == "U":
262                 print self.url
263             elif i == "I":
264                 # return without raising the normal exception
265                 return
266             elif i == "R":
267                 raise self.failureException(msg)
268             elif i == "X":
269                 self.exit()
270             print p,
271    
272     def exit(self):
273         sys.exit()
274    
275     if sys.version_info >= (2, 5):
276         def __call__(self, result=None):
277             if result is None:
278                 result = self.defaultTestResult()
279             result.startTest(self)
280             testMethod = getattr(self, self._testMethodName)
281             try:
282                 try:
283                     self.setUp()
284                 except (KeyboardInterrupt, SystemExit):
285                     raise
286                 except:
287                     result.addError(self, self._exc_info())
288                     return
289                
290                 ok = 0
291                 try:
292                     testMethod()
293                     ok = 1
294                 except self.failureException:
295                     result.addFailure(self, self._exc_info())
296                 except (KeyboardInterrupt, SystemExit):
297                     raise
298                 except:
299                     result.addError(self, self._exc_info())
300                
301                 try:
302                     self.tearDown()
303                 except (KeyboardInterrupt, SystemExit):
304                     raise
305                 except:
306                     result.addError(self, self._exc_info())
307                     ok = 0
308                 if ok:
309                     result.addSuccess(self)
310             finally:
311                 result.stopTest(self)
312     else:
313         def __call__(self, result=None):
314             if result is None:
315                 result = self.defaultTestResult()
316             result.startTest(self)
317             testMethod = getattr(self, self._TestCase__testMethodName)
318             try:
319                 try:
320                     self.setUp()
321                 except (KeyboardInterrupt, SystemExit):
322                     raise
323                 except:
324                     result.addError(self, self._TestCase__exc_info())
325                     return
326                
327                 ok = 0
328                 try:
329                     testMethod()
330                     ok = 1
331                 except self.failureException:
332                     result.addFailure(self, self._TestCase__exc_info())
333                 except (KeyboardInterrupt, SystemExit):
334                     raise
335                 except:
336                     result.addError(self, self._TestCase__exc_info())
337                
338                 try:
339                     self.tearDown()
340                 except (KeyboardInterrupt, SystemExit):
341                     raise
342                 except:
343                     result.addError(self, self._TestCase__exc_info())
344                     ok = 0
345                 if ok:
346                     result.addSuccess(self)
347             finally:
348                 result.stopTest(self)
349    
350     def assertStatus(self, status, msg=None):
351         """Fail if self.status != status."""
352         if isinstance(status, basestring):
353             if not self.status == status:
354                 if msg is None:
355                     msg = 'Status (%r) != %r' % (self.status, status)
356                 self._handlewebError(msg)
357         elif isinstance(status, int):
358             code = int(self.status[:3])
359             if code != status:
360                 if msg is None:
361                     msg = 'Status (%r) != %r' % (self.status, status)
362                 self._handlewebError(msg)
363         else:
364             # status is a tuple or list.
365             match = False
366             for s in status:
367                 if isinstance(s, basestring):
368                     if self.status == s:
369                         match = True
370                         break
371                 elif int(self.status[:3]) == s:
372                     match = True
373                     break
374             if not match:
375                 if msg is None:
376                     msg = 'Status (%r) not in %r' % (self.status, status)
377                 self._handlewebError(msg)
378    
379     def assertHeader(self, key, value=None, msg=None):
380         """Fail if (key, [value]) not in self.headers."""
381         lowkey = key.lower()
382         for k, v in self.headers:
383             if k.lower() == lowkey:
384                 if value is None or str(value) == v:
385                     return v
386        
387         if msg is None:
388             if value is None:
389                 msg = '%r not in headers' % key
390             else:
391                 msg = '%r:%r not in headers' % (key, value)
392         self._handlewebError(msg)
393    
394     def assertNoHeader(self, key, msg=None):
395         """Fail if key in self.headers."""
396         lowkey = key.lower()
397         matches = [k for k, v in self.headers if k.lower() == lowkey]
398         if matches:
399             if msg is None:
400                 msg = '%r in headers' % key
401             self._handlewebError(msg)
402    
403     def assertBody(self, value, msg=None):
404         """Fail if value != self.body."""
405         if value != self.body:
406             if msg is None:
407                 msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body)
408             self._handlewebError(msg)
409    
410     def assertInBody(self, value, msg=None):
411         """Fail if value not in self.body."""
412         if value not in self.body:
413             if msg is None:
414                 msg = '%r not in body' % value
415             self._handlewebError(msg)
416    
417     def assertNotInBody(self, value, msg=None):
418         """Fail if value in self.body."""
419         if value in self.body:
420             if msg is None:
421                 msg = '%r found in body' % value
422             self._handlewebError(msg)
423    
424     def assertMatchesBody(self, pattern, msg=None, flags=0):
425         """Fail if value (a regex pattern) is not in self.body."""
426         if re.search(pattern, self.body, flags) is None:
427             if msg is None:
428                 msg = 'No match for %r in body' % pattern
429             self._handlewebError(msg)
430
431
# HTTP methods for which cleanHeaders() supplies default Content-Type and
# Content-Length request headers.
methods_with_bodies = ("POST", "PUT")


def cleanHeaders(headers, method, body, host, port):
    """Return request headers, with required headers added (if missing).

    headers: a sequence of (name, value) pairs, or None. It is copied, not
        modified, so callers may safely reuse the same list across requests.
    method: the HTTP method. For methods in methods_with_bodies, a default
        Content-Type of application/x-www-form-urlencoded is added when none
        is present, together with a Content-Length computed from body.
    body: the request body (None counts as empty), used for Content-Length.
    host, port: used to build a missing Host header. The ":port" suffix is
        omitted when port == 80.

    Header names are matched case-insensitively. Returns a new list.
    """
    if headers is None:
        headers = []
    else:
        # Copy so the caller's list is never mutated (and tuples work too).
        headers = list(headers)

    names = [k.lower() for k, v in headers]

    # Add the required Host request header if not present.
    # [This specifies the host:port of the server, not the client.]
    if 'host' not in names:
        if port == 80:
            headers.append(("Host", host))
        else:
            headers.append(("Host", "%s:%s" % (host, port)))

    if method in methods_with_bodies:
        # Stick in default type and length headers if not present.
        # NOTE: Content-Length is only added together with the default
        # Content-Type; if the caller supplies a Content-Type, no
        # Content-Length is added here.
        if 'content-type' not in names:
            headers.append(("Content-Type", "application/x-www-form-urlencoded"))
            headers.append(("Content-Length", str(len(body or ""))))

    return headers
464
465
def shb(response):
    """Return status, headers, body the way we like from a response.

    Returns a 3-tuple: "<status> <reason>", a list of (name, value) pairs
    in the order they appear in response.msg.headers, and the body from
    response.read().

    response.msg.headers is treated as the raw header lines. A line
    starting with a space or tab continues the previous header. It is
    joined to that header's value with a single space, as RFC 2616
    sec. 2.2 gives folded whitespace the meaning of one SP. Headers whose
    value is empty are kept.
    """
    h = []
    key, value = None, None
    for line in response.msg.headers:
        if not line:
            continue
        if line[0] in " \t":
            if key is None:
                # Nothing to continue yet; ignore the stray folded line.
                continue
            folded = line.strip()
            if not value:
                value = folded
            elif folded:
                value = "%s %s" % (value, folded)
        else:
            if key is not None:
                h.append((key, value))
            key, value = line.split(":", 1)
            key = key.strip()
            value = value.strip()
    if key is not None:
        h.append((key, value))

    return "%s %s" % (response.status, response.reason), h, response.read()
484