Download Install Tutorial Docs FAQ Tools Wiki License Team IRC Planet Involvement Shop Book

root/trunk/cherrypy/test/webtest.py

Revision 2637 (checked in by jtate, 3 weeks ago)

Convert the tests to use nose instead of our own runner. This strips out much of the coverage and profiling code (now handled by nose) and lets you focus on writing tests.

The biggest change needed in the test classes is that you have to put the "setup_server" method on the class(es) that need it when running. If you need it in multiple classes, you can use staticmethod() to attach it to each of them without using inheritance.

  • Property svn:eol-style set to native
Line 
1 """Extensions to unittest for web frameworks.
2
3 Use the WebCase.getPage method to request a page from your HTTP server.
4
5 Framework Integration
6 =====================
7
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
12
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
17 """
18
19 from httplib import HTTPConnection, HTTPSConnection
20 import os
21 import pprint
22 import re
23 import socket
24 import sys
25 import time
26 import traceback
27 import types
28
29 from unittest import *
30 from unittest import _TextTestResult
31
32
33
def interface(host):
    """Map a server's bind address to an address a client can connect to.

    The wildcard addresses '0.0.0.0' (INADDR_ANY) and '::' (IN6ADDR_ANY)
    are translated to the matching loopback address ("127.0.0.1" and
    "::1" respectively); any other host is returned unchanged.
    """
    loopback_for_wildcard = {
        '0.0.0.0': "127.0.0.1",  # INADDR_ANY, which should respond on localhost.
        '::': "::1",             # IN6ADDR_ANY, which should respond on localhost.
    }
    return loopback_for_wildcard.get(host, host)
46
47
class TerseTestResult(_TextTestResult):
    """A text test result whose error report stays silent when all is well."""

    def printErrors(self):
        """Print the collected errors and failures, if there are any.

        Overridden to avoid an unnecessary empty line when there is
        nothing to report.
        """
        if not (self.errors or self.failures):
            return
        if self.dots or self.showAll:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors),
                                  ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
57
58
class TerseTestRunner(TextTestRunner):
    """A test runner class that displays results in textual form.

    Results are collected in a TerseTestResult; only the errors and, on
    an unsuccessful run, a one-line FAILED summary are written.
    """

    def _makeResult(self):
        """Return a fresh TerseTestResult writing to this runner's stream."""
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        """Run the given test case or test suite and return its result.

        Overridden to remove unnecessary empty lines and separators: after
        the errors are printed, an unsuccessful run gets a single
        "FAILED (failures=N, errors=M)" line, listing only non-zero counts.
        """
        outcome = self._makeResult()
        test(outcome)
        outcome.printErrors()
        if outcome.wasSuccessful():
            return outcome

        counts = []
        if outcome.failures:
            counts.append("failures=%d" % len(outcome.failures))
        if outcome.errors:
            counts.append("errors=%d" % len(outcome.errors))
        self.stream.write("FAILED (")
        self.stream.write(", ".join(counts))
        self.stream.writeln(")")
        return outcome
81
82
83 class ReloadingTestLoader(TestLoader):
84
85     def loadTestsFromName(self, name, module=None):
86         """Return a suite of all tests cases given a string specifier.
87
88         The name may resolve either to a module, a test case class, a
89         test method within a test case class, or a callable object which
90         returns a TestCase or TestSuite instance.
91
92         The method optionally resolves the names relative to a given module.
93         """
94         parts = name.split('.')
95         unused_parts = []
96         if module is None:
97             if not parts:
98                 raise ValueError("incomplete test name: %s" % name)
99             else:
100                 parts_copy = parts[:]
101                 while parts_copy:
102                     target = ".".join(parts_copy)
103                     if target in sys.modules:
104                         module = reload(sys.modules[target])
105                         parts = unused_parts
106                         break
107                     else:
108                         try:
109                             module = __import__(target)
110                             parts = unused_parts
111                             break
112                         except ImportError:
113                             unused_parts.insert(0,parts_copy[-1])
114                             del parts_copy[-1]
115                             if not parts_copy:
116                                 raise
117                 parts = parts[1:]
118         obj = module
119         for part in parts:
120             obj = getattr(obj, part)
121
122         if type(obj) == types.ModuleType:
123             return self.loadTestsFromModule(obj)
124         elif (isinstance(obj, (type, types.ClassType)) and
125               issubclass(obj, TestCase)):
126             return self.loadTestsFromTestCase(obj)
127         elif type(obj) == types.UnboundMethodType:
128             return obj.im_class(obj.__name__)
129         elif callable(obj):
130             test = obj()
131             if not isinstance(test, TestCase) and \
132                not isinstance(test, TestSuite):
133                 raise ValueError("calling %s returned %s, "
134                                  "not a test" % (obj,test))
135             return test
136         else:
137             raise ValueError("do not know how to make test from: %s" % obj)
138
139
try:
    import msvcrt
except ImportError:
    # Unix: switch the terminal to raw mode just long enough to read one
    # character from stdin, then restore the previous terminal settings.
    import tty, termios

    def getchar():
        """Read and return a single character from stdin in raw mode."""
        fd = sys.stdin.fileno()
        saved_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            return sys.stdin.read(1)
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, saved_settings)
else:
    # Windows: msvcrt.getch reads a single char without output.
    def getchar():
        """Read and return a single keypress without echoing it."""
        return msvcrt.getch()
157
158
class WebCase(TestCase):
    """A TestCase with helpers for requesting pages and checking responses.

    Use getPage() to request a URL from the server at HOST:PORT. The
    response is stored in self.status, self.headers and self.body, which
    the assert* methods below inspect. When one of those assertions fails,
    _handlewebError() either raises self.failureException (when
    self.interactive is false) or prompts on the console.
    """

    HOST = "127.0.0.1"
    PORT = 8000
    # A connection class (a fresh connection is made for each request) or,
    # after set_persistent(), a single open connection instance.
    HTTP_CONN = HTTPConnection
    PROTOCOL = "HTTP/1.1"

    scheme = "http"
    url = None

    # Filled in by getPage() from the most recent response.
    status = None
    headers = None
    body = None
    time = None
    # ('Cookie', value) pairs built by getPage() from the last response's
    # Set-Cookie headers.
    cookies = None

    def get_conn(self, auto_open=False):
        """Return a new, already connected connection to our HTTP server.

        HTTPSConnection is used when self.scheme is "https", HTTPConnection
        otherwise. auto_open is stored on the connection before connecting.
        """
        if self.scheme == "https":
            cls = HTTPSConnection
        else:
            cls = HTTPConnection
        conn = cls(self.interface(), self.PORT)
        # Automatically re-connect?
        conn.auto_open = auto_open
        conn.connect()
        return conn

    def set_persistent(self, on=True, auto_open=False):
        """Make our HTTP_CONN persistent (or not).

        If the 'on' argument is True (the default), then self.HTTP_CONN
        will be set to an instance of HTTPConnection (or HTTPS
        if self.scheme is "https"). This will then persist across requests.

        We only allow for a single open connection, so if you call this
        and we currently have an open connection, it will be closed.
        """
        try:
            self.HTTP_CONN.close()
        except (TypeError, AttributeError):
            # HTTP_CONN is still a class, so there is nothing to close.
            pass

        if on:
            self.HTTP_CONN = self.get_conn(auto_open=auto_open)
        else:
            if self.scheme == "https":
                self.HTTP_CONN = HTTPSConnection
            else:
                self.HTTP_CONN = HTTPConnection

    def _get_persistent(self):
        # A connection instance has a "host" attribute; the connection
        # classes do not. openURL() tells the two apart the same way. (The
        # old "__class__" test only worked for Python 2 classic classes.)
        return hasattr(self.HTTP_CONN, "host")
    def _set_persistent(self, on):
        self.set_persistent(on)
    persistent = property(_get_persistent, _set_persistent)

    def interface(self):
        """Return an IP address for a client connection.

        If the server is listening on '0.0.0.0' (INADDR_ANY)
        or '::' (IN6ADDR_ANY), this will return the proper localhost."""
        return interface(self.HOST)

    def getPage(self, url, headers=None, method="GET", body=None, protocol=None):
        """Open the url with debugging support. Return status, headers, body.

        The request goes to self.HOST:self.PORT through self.HTTP_CONN using
        protocol (default self.PROTOCOL). The response is also stored in
        self.status, self.headers and self.body, the elapsed seconds in
        self.time and the url in self.url. self.cookies is rebuilt from the
        response's Set-Cookie headers.

        Raises ServerError if server_error() trapped an exception while
        the request was being handled.
        """
        ServerError.on = False

        self.url = url
        self.time = None
        start = time.time()
        result = openURL(url, headers, method, body, self.HOST, self.PORT,
                         self.HTTP_CONN, protocol or self.PROTOCOL)
        self.time = time.time() - start
        self.status, self.headers, self.body = result

        # Build a list of request cookies from the previous response cookies.
        self.cookies = [('Cookie', v) for k, v in self.headers
                        if k.lower() == 'set-cookie']

        if ServerError.on:
            raise ServerError()
        return result

    # When true, a failed assertion prompts on the console instead of
    # raising self.failureException right away.
    interactive = True
    # Number of body lines shown per page by the [B]ody option.
    console_height = 30

    def _handlewebError(self, msg):
        """Report a failed assertion; interactively if self.interactive.

        Prints msg. When self.interactive is false, raises
        self.failureException(msg). Otherwise it reads single keystrokes
        (via getchar) to show the body, headers, status or URL. [I]gnore
        returns normally, [R]aise raises self.failureException(msg), and
        e[X]it calls self.exit().
        """
        print("")
        print("    ERROR: %s" % msg)

        if not self.interactive:
            raise self.failureException(msg)

        p = "    Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
        sys.stdout.write(p)
        # Flush so the prompt is visible before we block on the keypress.
        sys.stdout.flush()
        while True:
            i = getchar().upper()
            if not i:
                # No input left (EOF). The empty string is "in" every
                # string, so without this check we would loop forever.
                raise self.failureException(msg)
            if i not in "BHSUIRX":
                continue
            print(i)  # Also prints new line
            if i == "B":
                for x, line in enumerate(self.body.splitlines()):
                    if (x + 1) % self.console_height == 0:
                        # The trailing \r lets the next line overwrite this prompt
                        sys.stdout.write("<-- More -->\r")
                        m = getchar().lower()
                        # Erase our "More" prompt
                        sys.stdout.write("            \r")
                        if m == "q":
                            break
                    print(line)
            elif i == "H":
                pprint.pprint(self.headers)
            elif i == "S":
                print(self.status)
            elif i == "U":
                print(self.url)
            elif i == "I":
                # return without raising the normal exception
                return
            elif i == "R":
                raise self.failureException(msg)
            elif i == "X":
                self.exit()
            sys.stdout.write(p)
            sys.stdout.flush()

    def exit(self):
        """Exit the process (invoked by the interactive e[X]it choice)."""
        sys.exit()

    def assertStatus(self, status, msg=None):
        """Fail if self.status != status.

        status may be a full status string ("200 OK"), an int code, or a
        tuple/list of either; in the last case any match passes.
        """
        if isinstance(status, basestring):
            if not self.status == status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        elif isinstance(status, int):
            code = int(self.status[:3])
            if code != status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        else:
            # status is a tuple or list.
            match = False
            for s in status:
                if isinstance(s, basestring):
                    if self.status == s:
                        match = True
                        break
                elif int(self.status[:3]) == s:
                    match = True
                    break
            if not match:
                if msg is None:
                    msg = 'Status (%r) not in %r' % (self.status, status)
                self._handlewebError(msg)

    def assertHeader(self, key, value=None, msg=None):
        """Fail if (key, [value]) not in self.headers.

        key is matched case-insensitively; value (if given) is compared as
        str(value). Returns the matching header's value when found.
        """
        lowkey = key.lower()
        for k, v in self.headers:
            if k.lower() == lowkey:
                if value is None or str(value) == v:
                    return v

        if msg is None:
            if value is None:
                msg = '%r not in headers' % key
            else:
                msg = '%r:%r not in headers' % (key, value)
        self._handlewebError(msg)

    def assertHeaderItemValue(self, key, value, msg=None):
        """Fail if the header does not contain the specified value.

        The header's value is split on commas and each item stripped before
        comparing. Returns value when it is found.
        """
        actual_value = self.assertHeader(key, msg=msg)
        if actual_value is None:
            # The header is missing and that failure was ignored
            # interactively; there is nothing to split.
            return None
        header_values = [item.strip() for item in actual_value.split(',')]
        if value in header_values:
            return value

        if msg is None:
            msg = "%r not in %r" % (value, header_values)
        self._handlewebError(msg)

    def assertNoHeader(self, key, msg=None):
        """Fail if key in self.headers (compared case-insensitively)."""
        lowkey = key.lower()
        matches = [k for k, v in self.headers if k.lower() == lowkey]
        if matches:
            if msg is None:
                msg = '%r in headers' % key
            self._handlewebError(msg)

    def assertBody(self, value, msg=None):
        """Fail if value != self.body."""
        if value != self.body:
            if msg is None:
                msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body)
            self._handlewebError(msg)

    def assertInBody(self, value, msg=None):
        """Fail if value not in self.body."""
        if value not in self.body:
            if msg is None:
                msg = '%r not in body: %s' % (value, self.body)
            self._handlewebError(msg)

    def assertNotInBody(self, value, msg=None):
        """Fail if value in self.body."""
        if value in self.body:
            if msg is None:
                msg = '%r found in body' % value
            self._handlewebError(msg)

    def assertMatchesBody(self, pattern, msg=None, flags=0):
        """Fail if value (a regex pattern) is not in self.body."""
        if re.search(pattern, self.body, flags) is None:
            if msg is None:
                msg = 'No match for %r in body' % pattern
            self._handlewebError(msg)
383
384
methods_with_bodies = ("POST", "PUT")

def cleanHeaders(headers, method, body, host, port):
    """Return request headers, with required headers added (if missing).

    headers is a sequence of (name, value) pairs, or None. It is copied and
    never modified in place. A caller that reuses one list across requests
    would otherwise carry headers added here (e.g. a stale Content-Length)
    into later requests.

    - A Host header naming the server ("host", or "host:port" when port is
      not 80) is appended unless a Host header is already present.
    - For methods in methods_with_bodies (compared case-insensitively,
      matching openURL, which upper-cases the method it sends), a default
      Content-Type of application/x-www-form-urlencoded and a matching
      Content-Length are appended when no Content-Type is present. If the
      caller supplies its own Content-Type, no Content-Length is added
      (presumably so tests can send length-less requests deliberately).
    """
    if headers is None:
        headers = []
    else:
        headers = list(headers)

    names = [k.lower() for k, v in headers]

    # Add the required Host request header if not present.
    # [This specifies the host:port of the server, not the client.]
    if 'host' not in names:
        if port == 80:
            headers.append(("Host", host))
        else:
            headers.append(("Host", "%s:%s" % (host, port)))

    if method.upper() in methods_with_bodies:
        # Stick in default type and length headers if not present
        if 'content-type' not in names:
            headers.append(("Content-Type", "application/x-www-form-urlencoded"))
            headers.append(("Content-Length", str(len(body or ""))))

    return headers
417
418
def shb(response):
    """Return status, headers, body the way we like from a response.

    Returns a tuple (status, headers, body):
      - status: "<status> <reason>", e.g. "200 OK";
      - headers: a list of (name, value) pairs in the order received,
        keeping repeated headers and headers with empty values. Folded
        continuation lines (starting with a space or tab) are joined to
        the preceding header's value with a single space;
      - body: the result of response.read().

    Assumes response.msg.headers holds the raw header lines, as Python 2's
    httplib provides; verify before using with another HTTP client.
    """
    h = []
    for line in response.msg.headers:
        if not line:
            continue
        if line[0] in " \t":
            # Continuation of the previous header; one with nothing before
            # it cannot belong anywhere and is skipped.
            if h:
                key, value = h[-1]
                h[-1] = (key, (value + " " + line.strip()).strip())
        else:
            key, value = line.split(":", 1)
            h.append((key.strip(), value.strip()))

    return "%s %s" % (response.status, response.reason), h, response.read()
437
438
def openURL(url, headers=None, method="GET", body=None,
            host="127.0.0.1", port=8000, http_conn=HTTPConnection,
            protocol="HTTP/1.1"):
    """Open the given HTTP resource and return status, headers, and body.

    http_conn may be a connection class, in which case a new connection to
    interface(host):port is made for this request and closed afterwards
    (even on error). It may instead be an existing connection instance
    (anything with a "host" attribute), which is used as-is and left open.

    protocol sets the HTTP version string sent in the request line. The
    request is attempted up to 10 times, sleeping 0.5 seconds after each
    socket.error; if every attempt fails, the last socket.error is
    re-raised.

    Returns the (status, headers, body) tuple produced by shb().
    """
    headers = cleanHeaders(headers, method, body, host, port)

    # Allow http_conn to be a class or an instance: an instance has a
    # "host" attribute, the class does not.
    own_conn = not hasattr(http_conn, "host")

    # Trying 10 times is simply in case of socket errors.
    # Normal case--it should run once.
    attempts = 10
    for trial in range(attempts):
        try:
            if own_conn:
                conn = http_conn(interface(host), port)
            else:
                conn = http_conn
            try:
                conn._http_vsn_str = protocol
                conn._http_vsn = int("".join([x for x in protocol if x.isdigit()]))

                # skip_accept_encoding argument added in python version 2.4
                if sys.version_info < (2, 4):
                    def putheader(self, header, value):
                        if header == 'Accept-Encoding' and value == 'identity':
                            return
                        self.__class__.putheader(self, header, value)
                    import new
                    conn.putheader = new.instancemethod(putheader, conn, conn.__class__)
                    conn.putrequest(method.upper(), url, skip_host=True)
                else:
                    conn.putrequest(method.upper(), url, skip_host=True,
                                    skip_accept_encoding=True)

                for key, value in headers:
                    conn.putheader(key, value)
                conn.endheaders()

                if body is not None:
                    conn.send(body)

                # Handle response
                response = conn.getresponse()
                return shb(response)
            finally:
                if own_conn:
                    # We made our own conn instance. Close it, even if the
                    # request failed part-way.
                    conn.close()
        except socket.error:
            if trial == attempts - 1:
                # Out of retries: re-raise the last socket error while it is
                # still the active exception (a bare raise after the loop
                # only worked through Python 2's leftover exc_info).
                raise
            time.sleep(0.5)
492
493
# Add any exceptions which your web framework handles
# normally (that you don't want server_error to trap).
# server_error compares the exception's class against this list by
# membership, so list every class to ignore; subclasses of a listed
# class are not matched automatically.
ignored_exceptions = []

# You'll want to set this to True when you can't guarantee
# that each response will immediately follow each request;
# for example, when handling requests via multiple threads.
# When True, server_error ignores (returns False for) every exception.
ignore_all = False
502
class ServerError(Exception):
    """Raised by WebCase.getPage after server_error trapped an exception."""

    # Class-level flag: set to True by server_error(); reset to False at
    # the start of each WebCase.getPage() call.
    on = False
505
506
def server_error(exc=None):
    """Server debug hook. Return True if exception handled, False if ignored.

    You probably want to wrap this, so you can still handle an error using
    your framework when it's ignored.

    exc is an exc_info-style triple and defaults to sys.exc_info(). The
    error is ignored when ignore_all is set or its class is listed in
    ignored_exceptions. Otherwise ServerError.on is set, which
    WebCase.getPage checks after the request, and the traceback is printed.
    """
    if exc is None:
        exc = sys.exc_info()

    should_ignore = ignore_all or exc[0] in ignored_exceptions
    if should_ignore:
        return False

    ServerError.on = True
    print("")
    print("".join(traceback.format_exception(*exc)))
    return True
523
Note: See TracBrowser for help on using the browser.

Hosted by WebFaction

Log in as guest/cpguest to create tickets