Download Install Tutorial Docs FAQ Tools Wiki License Team IRC Planet Involvement Shop Book

root/trunk/cherrypy/test/webtest.py

Revision 2668 (checked in by fumanchu, 4 months ago)

A couple test tweaks.

  • Property svn:eol-style set to native
Line 
1 """Extensions to unittest for web frameworks.
2
3 Use the WebCase.getPage method to request a page from your HTTP server.
4
5 Framework Integration
6 =====================
7
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
12
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
17 """
18
19 from httplib import HTTPConnection, HTTPSConnection
20 import os
21 import pprint
22 import re
23 import socket
24 import sys
25 import time
26 import traceback
27 import types
28
29 from unittest import *
30 from unittest import _TextTestResult
31
32
33
def interface(host):
    """Map a server bind address to an address a client can connect to.

    A server listening on the IPv4 wildcard '0.0.0.0' (INADDR_ANY) or the
    IPv6 wildcard '::' (IN6ADDR_ANY) is reached through the matching
    loopback address; any other host is returned unchanged.
    """
    wildcard_to_loopback = (
        ('0.0.0.0', '127.0.0.1'),  # INADDR_ANY answers on IPv4 localhost
        ('::', '::1'),             # IN6ADDR_ANY answers on IPv6 localhost
    )
    for wildcard, loopback in wildcard_to_loopback:
        if host == wildcard:
            return loopback
    return host
46
47
class TerseTestResult(_TextTestResult):
    """Text test result whose error report prints nothing on a clean run."""

    def printErrors(self):
        # The inherited version always starts with a newline (in dots or
        # verbose mode), leaving a stray blank line when nothing failed.
        if not (self.errors or self.failures):
            return
        if self.dots or self.showAll:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors),
                                  ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
57
58
class TerseTestRunner(TextTestRunner):
    """A test runner class that displays results in textual form.

    Unlike the stock runner it prints no separators, timing line or blank
    lines, and writes a "FAILED (...)" tally only when the run failed.
    """

    def _makeResult(self):
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        "Run the given test case or test suite."
        outcome = self._makeResult()
        test(outcome)
        outcome.printErrors()
        if outcome.wasSuccessful():
            return outcome

        n_failed = len(outcome.failures)
        n_errored = len(outcome.errors)
        tally = []
        if n_failed:
            tally.append("failures=%d" % n_failed)
        if n_errored:
            tally.append("errors=%d" % n_errored)
        self.stream.write("FAILED (" + ", ".join(tally))
        self.stream.writeln(")")
        return outcome
81
82
83 class ReloadingTestLoader(TestLoader):
84
85     def loadTestsFromName(self, name, module=None):
86         """Return a suite of all tests cases given a string specifier.
87
88         The name may resolve either to a module, a test case class, a
89         test method within a test case class, or a callable object which
90         returns a TestCase or TestSuite instance.
91
92         The method optionally resolves the names relative to a given module.
93         """
94         parts = name.split('.')
95         unused_parts = []
96         if module is None:
97             if not parts:
98                 raise ValueError("incomplete test name: %s" % name)
99             else:
100                 parts_copy = parts[:]
101                 while parts_copy:
102                     target = ".".join(parts_copy)
103                     if target in sys.modules:
104                         module = reload(sys.modules[target])
105                         parts = unused_parts
106                         break
107                     else:
108                         try:
109                             module = __import__(target)
110                             parts = unused_parts
111                             break
112                         except ImportError:
113                             unused_parts.insert(0,parts_copy[-1])
114                             del parts_copy[-1]
115                             if not parts_copy:
116                                 raise
117                 parts = parts[1:]
118         obj = module
119         for part in parts:
120             obj = getattr(obj, part)
121
122         if type(obj) == types.ModuleType:
123             return self.loadTestsFromModule(obj)
124         elif (isinstance(obj, (type, types.ClassType)) and
125               issubclass(obj, TestCase)):
126             return self.loadTestsFromTestCase(obj)
127         elif type(obj) == types.UnboundMethodType:
128             return obj.im_class(obj.__name__)
129         elif callable(obj):
130             test = obj()
131             if not isinstance(test, TestCase) and \
132                not isinstance(test, TestSuite):
133                 raise ValueError("calling %s returned %s, "
134                                  "not a test" % (obj,test))
135             return test
136         else:
137             raise ValueError("do not know how to make test from: %s" % obj)
138
139
try:
    # Windows console: msvcrt.getch reads a single char without output.
    import msvcrt
except ImportError:
    # Unix: switch the terminal to raw mode just long enough to read one
    # character, then restore the previous terminal settings.
    import tty, termios

    def getchar():
        """Read one character from stdin with the terminal in raw mode."""
        fd = sys.stdin.fileno()
        saved_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            return sys.stdin.read(1)
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, saved_settings)
else:
    def getchar():
        """Read one keypress from the console via msvcrt.getch()."""
        return msvcrt.getch()
157
158
class WebCase(TestCase):
    """A TestCase that requests pages from an HTTP server and checks them.

    getPage() sends a request and records the response on the instance
    (status, headers, body, plus url, time and cookies); the assert*
    methods then check that recorded response. Every failed assertion is
    routed through _handlewebError(), which raises self.failureException
    unless ``interactive`` is true, in which case it first lets the user
    inspect the response at the console.
    """
    HOST = "127.0.0.1"
    PORT = 8000
    # A connection class (a new connection per request) or, after
    # set_persistent(True), a connection instance reused across requests.
    HTTP_CONN = HTTPConnection
    PROTOCOL = "HTTP/1.1"

    scheme = "http"
    url = None

    # Details of the most recent response; set by getPage().
    status = None
    headers = None
    body = None
    time = None

    def get_conn(self, auto_open=False):
        """Return a connection to our HTTP server.

        An HTTPSConnection is used when self.scheme is "https", otherwise an
        HTTPConnection. The connection is opened before it is returned, and
        ``auto_open`` is stored as its auto_open attribute.
        """
        if self.scheme == "https":
            cls = HTTPSConnection
        else:
            cls = HTTPConnection
        conn = cls(self.interface(), self.PORT)
        # Automatically re-connect?
        conn.auto_open = auto_open
        conn.connect()
        return conn

    def set_persistent(self, on=True, auto_open=False):
        """Make our HTTP_CONN persistent (or not).

        If the 'on' argument is True (the default), then self.HTTP_CONN
        will be set to an instance of HTTPConnection (or HTTPS
        if self.scheme is "https"). This will then persist across requests.

        We only allow for a single open connection, so if you call this
        and we currently have an open connection, it will be closed.
        """
        try:
            self.HTTP_CONN.close()
        except (TypeError, AttributeError):
            # HTTP_CONN is still a connection class, so there is no open
            # connection whose close() could be called.
            pass

        if on:
            self.HTTP_CONN = self.get_conn(auto_open=auto_open)
        else:
            if self.scheme == "https":
                self.HTTP_CONN = HTTPSConnection
            else:
                self.HTTP_CONN = HTTPConnection

    def _get_persistent(self):
        # Only connection *instances* carry a ``host`` attribute; openURL
        # tells classes and instances apart the same way. The former test,
        # hasattr(..., "__class__"), is also true for any new-style class,
        # so it reported a persistent connection when there was none.
        return hasattr(self.HTTP_CONN, "host")
    def _set_persistent(self, on):
        self.set_persistent(on)
    persistent = property(_get_persistent, _set_persistent)

    def interface(self):
        """Return an IP address for a client connection.

        If the server is listening on '0.0.0.0' (INADDR_ANY)
        or '::' (IN6ADDR_ANY), this will return the proper localhost."""
        return interface(self.HOST)

    def getPage(self, url, headers=None, method="GET", body=None, protocol=None):
        """Open the url with debugging support. Return status, headers, body.

        The response is also stored on self.status, self.headers and
        self.body. self.url and self.time (elapsed seconds) are recorded, and
        self.cookies is set to ('Cookie', value) pairs built from the
        response's Set-Cookie headers. Raises ServerError if server_error()
        flagged a server-side failure while the request was in progress.
        """
        ServerError.on = False

        self.url = url
        self.time = None
        start = time.time()
        result = openURL(url, headers, method, body, self.HOST, self.PORT,
                         self.HTTP_CONN, protocol or self.PROTOCOL)
        self.time = time.time() - start
        self.status, self.headers, self.body = result

        # Build a list of request cookies from the previous response cookies.
        self.cookies = [('Cookie', v) for k, v in self.headers
                        if k.lower() == 'set-cookie']

        if ServerError.on:
            raise ServerError()
        return result

    interactive = True
    console_height = 30

    def _handlewebError(self, msg):
        """Report a failed web assertion described by ``msg``.

        Non-interactive: raise self.failureException(msg). Interactive:
        prompt on the console so the body, headers, status or URL can be
        shown, then [I]gnore (return normally), [R]aise
        self.failureException, or e[X]it via self.exit().
        """
        print("")
        print("    ERROR: %s" % msg)

        if not self.interactive:
            raise self.failureException(msg)

        p = "    Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
        # Write (not print) the prompt so the answer lands right after it,
        # and flush so it is visible before getchar() blocks.
        sys.stdout.write(p)
        sys.stdout.flush()
        while True:
            i = getchar().upper()
            if not i:
                # Nothing could be read (end of input). The old substring
                # test accepted "" and then re-prompted forever; nobody can
                # answer, so fail the assertion instead.
                raise self.failureException(msg)
            if i not in ("B", "H", "S", "U", "I", "R", "X"):
                continue
            print(i)  # Also prints new line
            if i == "B":
                for x, line in enumerate(self.body.splitlines()):
                    if (x + 1) % self.console_height == 0:
                        # The \r (and no newline) lets the next line
                        # overwrite the "More" prompt.
                        sys.stdout.write("<-- More -->\r")
                        sys.stdout.flush()
                        m = getchar().lower()
                        # Erase our "More" prompt
                        sys.stdout.write("            \r")
                        if m == "q":
                            break
                    print(line)
            elif i == "H":
                pprint.pprint(self.headers)
            elif i == "S":
                print(self.status)
            elif i == "U":
                print(self.url)
            elif i == "I":
                # return without raising the normal exception
                return
            elif i == "R":
                raise self.failureException(msg)
            elif i == "X":
                self.exit()
            sys.stdout.write(p)
            sys.stdout.flush()

    def exit(self):
        sys.exit()

    def assertStatus(self, status, msg=None):
        """Fail if self.status != status.

        ``status`` may be a full status string (e.g. "200 OK"), an int
        compared with the 3-digit code, or a sequence of either.
        """
        if isinstance(status, basestring):
            if not self.status == status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        elif isinstance(status, int):
            code = int(self.status[:3])
            if code != status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        else:
            # status is a tuple or list.
            match = False
            for s in status:
                if isinstance(s, basestring):
                    if self.status == s:
                        match = True
                        break
                elif int(self.status[:3]) == s:
                    match = True
                    break
            if not match:
                if msg is None:
                    msg = 'Status (%r) not in %r' % (self.status, status)
                self._handlewebError(msg)

    def assertHeader(self, key, value=None, msg=None):
        """Fail if (key, [value]) not in self.headers.

        The header name is matched case-insensitively. Returns the value of
        the first matching header (None if the failure was ignored).
        """
        lowkey = key.lower()
        for k, v in self.headers:
            if k.lower() == lowkey:
                if value is None or str(value) == v:
                    return v

        if msg is None:
            if value is None:
                msg = '%r not in headers' % key
            else:
                msg = '%r:%r not in headers' % (key, value)
        self._handlewebError(msg)

    def assertHeaderItemValue(self, key, value, msg=None):
        """Fail if the header does not contain the specified value.

        The header value is split on commas and each item is stripped of
        surrounding whitespace before comparison. Returns ``value`` on
        success.
        """
        actual_value = self.assertHeader(key, msg=msg)
        if actual_value is None:
            # The header is missing and that failure was ignored
            # interactively; there is nothing to split.
            return None
        # A list (not map()) so it can be both searched and shown in msg,
        # and so unicode header values work too.
        header_values = [item.strip() for item in actual_value.split(',')]
        if value in header_values:
            return value

        if msg is None:
            msg = "%r not in %r" % (value, header_values)
        self._handlewebError(msg)

    def assertNoHeader(self, key, msg=None):
        """Fail if key in self.headers."""
        lowkey = key.lower()
        matches = [k for k, v in self.headers if k.lower() == lowkey]
        if matches:
            if msg is None:
                msg = '%r in headers' % key
            self._handlewebError(msg)

    def assertBody(self, value, msg=None):
        """Fail if value != self.body."""
        if value != self.body:
            if msg is None:
                msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body)
            self._handlewebError(msg)

    def assertInBody(self, value, msg=None):
        """Fail if value not in self.body."""
        if value not in self.body:
            if msg is None:
                msg = '%r not in body: %s' % (value, self.body)
            self._handlewebError(msg)

    def assertNotInBody(self, value, msg=None):
        """Fail if value in self.body."""
        if value in self.body:
            if msg is None:
                msg = '%r found in body' % value
            self._handlewebError(msg)

    def assertMatchesBody(self, pattern, msg=None, flags=0):
        """Fail if value (a regex pattern) is not in self.body."""
        if re.search(pattern, self.body, flags) is None:
            if msg is None:
                msg = 'No match for %r in body' % pattern
            self._handlewebError(msg)
382
383
# Request methods for which cleanHeaders() supplies default body headers.
methods_with_bodies = ("POST", "PUT")

def cleanHeaders(headers, method, body, host, port):
    """Return request headers, with required headers added (if missing).

    headers: a sequence of (name, value) pairs, or None. It is copied, so
        the caller's list is never modified.
    method: the request method; for methods in methods_with_bodies a
        default Content-Type ("application/x-www-form-urlencoded") and a
        Content-Length of len(body or "") are added, but only when no
        Content-Type header was supplied.
    body: the request body (or None), used only for Content-Length.
    host, port: used for the Host header, which is added unless present;
        the port is omitted from it when it is 80.

    Header names are matched case-insensitively. Returns a new list.
    """
    # Work on a copy: appending to the caller's list would leak these
    # headers (e.g. a Host for the wrong port) into later requests that
    # reuse the same list. list() also accepts a tuple of pairs.
    headers = list(headers or [])
    present = [k.lower() for k, v in headers]

    # Add the required Host request header if not present.
    # [This specifies the host:port of the server, not the client.]
    if 'host' not in present:
        if port == 80:
            headers.append(("Host", host))
        else:
            headers.append(("Host", "%s:%s" % (host, port)))

    if method in methods_with_bodies and 'content-type' not in present:
        # Stick in default type and length headers if not present
        headers.append(("Content-Type", "application/x-www-form-urlencoded"))
        headers.append(("Content-Length", str(len(body or ""))))

    return headers
416
417
def shb(response):
    """Return status, headers, body the way we like from a response.

    Returns a 3-tuple:
      * status -- "<status> <reason>", e.g. "200 OK";
      * headers -- a list of (name, value) pairs in the order received,
        names and values stripped of surrounding whitespace;
      * body -- whatever response.read() returns.

    Header lines are read from response.msg.headers (on Python 2's httplib,
    the raw header lines kept by the response's message object). Folded
    continuation lines (starting with a space or tab) are appended to the
    previous value with a single space, and headers with an empty value are
    kept rather than dropped.
    """
    h = []
    key, value = None, None
    for line in response.msg.headers:
        if not line:
            continue
        if line[0] in " \t":
            # Continuation line. RFC 2616 gives folding whitespace the
            # meaning of a single SP, so don't glue the words together.
            continuation = line.strip()
            if key is not None and continuation:
                if value:
                    value = "%s %s" % (value, continuation)
                else:
                    value = continuation
            continue
        # A new header starts: flush the previous one (even if its value
        # is empty -- the old "key and value" test silently dropped it).
        if key is not None:
            h.append((key, value))
        key, value = line.split(":", 1)
        key = key.strip()
        value = value.strip()
    if key is not None:
        h.append((key, value))

    return "%s %s" % (response.status, response.reason), h, response.read()
436
437
def openURL(url, headers=None, method="GET", body=None,
            host="127.0.0.1", port=8000, http_conn=HTTPConnection,
            protocol="HTTP/1.1"):
    """Open the given HTTP resource and return status, headers, and body.

    ``http_conn`` may be a connection class, in which case a fresh
    connection to interface(host):port is made for each attempt and closed
    afterwards. It may also be an existing connection instance, which is
    used as-is and left open. ``protocol`` (e.g. "HTTP/1.0") sets the HTTP
    version used for the request line.

    The request is retried up to 10 times on socket.error, sleeping half a
    second between attempts. If every attempt fails, the last socket.error
    is re-raised.
    """
    headers = cleanHeaders(headers, method, body, host, port)

    # Allow http_conn to be a class or an instance; only instances have
    # a ``host`` attribute.
    own_conn = not hasattr(http_conn, "host")

    # Trying 10 times is simply in case of socket errors.
    # Normal case--it should run once.
    attempts = 10
    last_error = None
    for trial in range(attempts):
        conn = None
        try:
            try:
                if own_conn:
                    conn = http_conn(interface(host), port)
                else:
                    conn = http_conn

                conn._http_vsn_str = protocol
                conn._http_vsn = int("".join([x for x in protocol if x.isdigit()]))

                # skip_accept_encoding argument added in python version 2.4
                if sys.version_info < (2, 4):
                    def putheader(self, header, value):
                        if header == 'Accept-Encoding' and value == 'identity':
                            return
                        self.__class__.putheader(self, header, value)
                    import new
                    conn.putheader = new.instancemethod(putheader, conn, conn.__class__)
                    conn.putrequest(method.upper(), url, skip_host=True)
                else:
                    conn.putrequest(method.upper(), url, skip_host=True,
                                    skip_accept_encoding=True)

                for key, value in headers:
                    conn.putheader(key, value)
                conn.endheaders()

                if body is not None:
                    conn.send(body)

                # Handle response
                response = conn.getresponse()
                return shb(response)
            except socket.error:
                # Keep the error explicitly: a bare ``raise`` after the loop
                # relies on Python 2 keeping sys.exc_info() alive past the
                # handler, and is a RuntimeError on Python 3.
                last_error = sys.exc_info()[1]
        finally:
            if own_conn and conn is not None:
                # We made our own conn instance. Close it, on success and
                # on failure alike, so failed attempts don't leak sockets.
                conn.close()
        if trial < attempts - 1:
            time.sleep(0.5)
    raise last_error
491
492
# Exception classes that your web framework handles normally and that
# server_error should therefore not trap; add them to this list.
ignored_exceptions = []

# Set this to True when you can't guarantee that each response will
# immediately follow its request -- e.g. when requests are handled by
# multiple threads.
ignore_all = False

class ServerError(Exception):
    """Raised by WebCase.getPage when server_error() recorded a failure."""
    # Set True by server_error(); reset at the start of each getPage call.
    on = False


def server_error(exc=None):
    """Server debug hook. Return True if exception handled, False if ignored.

    ``exc`` is an exc_info triple and defaults to sys.exc_info(). Unless
    ignore_all is set or the exception type is listed in
    ignored_exceptions, the traceback is printed to stdout and
    ServerError.on is set, which makes WebCase.getPage raise ServerError.

    You probably want to wrap this, so you can still handle an error using
    your framework when it's ignored.
    """
    exc_info = sys.exc_info() if exc is None else exc

    if ignore_all or exc_info[0] in ignored_exceptions:
        return False

    ServerError.on = True
    print("")
    print("".join(traceback.format_exception(*exc_info)))
    return True
522
Note: See TracBrowser for help on using the browser.

Hosted by WebFaction

Log in as guest/cpguest to create tickets