Download Install Tutorial Docs FAQ Tools Wiki License Team IRC Planet Involvement Shop Book

root/trunk/cherrypy/test/webtest.py

Revision 2437 (checked in by fumanchu, 3 weeks ago)

Removed py3print.

  • Property svn:eol-style set to native
Line 
1 """Extensions to unittest for web frameworks.
2
3 Use the WebCase.getPage method to request a page from your HTTP server.
4
5 Framework Integration
6 =====================
7
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
12
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
17 """
18
19 import httplib
20 import os
21 import pprint
22 import re
23 import socket
24 import sys
25 import time
26 import traceback
27 import types
28
29 from unittest import *
30 from unittest import _TextTestResult
31
32
33
def interface(host):
    """Return an IP address for a client connection given the server host.

    A server listening on a wildcard address should respond on localhost,
    so '0.0.0.0' (INADDR_ANY) maps to "127.0.0.1" and '::' (IN6ADDR_ANY)
    maps to "::1". Any other host is returned unchanged.
    """
    if host == '::':
        # IN6ADDR_ANY: use the IPv6 loopback address.
        return "::1"
    if host != '0.0.0.0':
        return host
    # INADDR_ANY: use the IPv4 loopback address.
    return "127.0.0.1"
46
47
class TerseTestResult(_TextTestResult):
    """Text test result whose error report skips the spurious blank line."""

    def printErrors(self):
        # Overridden so that nothing at all is written when every test passed.
        if not (self.errors or self.failures):
            return
        if self.dots or self.showAll:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors),
                                  ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
57
58
class TerseTestRunner(TextTestRunner):
    """A test runner class that displays results in textual form.

    Only the error report and, on failure, a one-line
    "FAILED (failures=N, errors=M)" summary are written.
    """

    def _makeResult(self):
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        """Run the given test case or test suite and return its result."""
        # Overridden to remove unnecessary empty lines and separators.
        result = self._makeResult()
        test(result)
        result.printErrors()
        if result.wasSuccessful():
            return result

        counts = []
        if result.failures:
            counts.append("failures=%d" % len(result.failures))
        if result.errors:
            counts.append("errors=%d" % len(result.errors))
        self.stream.write("FAILED (")
        self.stream.write(", ".join(counts))
        self.stream.writeln(")")
        return result
81
82
83 class ReloadingTestLoader(TestLoader):
84    
85     def loadTestsFromName(self, name, module=None):
86         """Return a suite of all tests cases given a string specifier.
87
88         The name may resolve either to a module, a test case class, a
89         test method within a test case class, or a callable object which
90         returns a TestCase or TestSuite instance.
91
92         The method optionally resolves the names relative to a given module.
93         """
94         parts = name.split('.')
95         unused_parts = []
96         if module is None:
97             if not parts:
98                 raise ValueError("incomplete test name: %s" % name)
99             else:
100                 parts_copy = parts[:]
101                 while parts_copy:
102                     target = ".".join(parts_copy)
103                     if target in sys.modules:
104                         module = reload(sys.modules[target])
105                         parts = unused_parts
106                         break
107                     else:
108                         try:
109                             module = __import__(target)
110                             parts = unused_parts
111                             break
112                         except ImportError:
113                             unused_parts.insert(0,parts_copy[-1])
114                             del parts_copy[-1]
115                             if not parts_copy:
116                                 raise
117                 parts = parts[1:]
118         obj = module
119         for part in parts:
120             obj = getattr(obj, part)
121        
122         if type(obj) == types.ModuleType:
123             return self.loadTestsFromModule(obj)
124         elif (isinstance(obj, (type, types.ClassType)) and
125               issubclass(obj, TestCase)):
126             return self.loadTestsFromTestCase(obj)
127         elif type(obj) == types.UnboundMethodType:
128             return obj.im_class(obj.__name__)
129         elif callable(obj):
130             test = obj()
131             if not isinstance(test, TestCase) and \
132                not isinstance(test, TestSuite):
133                 raise ValueError("calling %s returned %s, "
134                                  "not a test" % (obj,test))
135             return test
136         else:
137             raise ValueError("do not know how to make test from: %s" % obj)
138
139
try:
    # Windows: msvcrt can read one keypress without echoing it.
    import msvcrt
except ImportError:
    # Elsewhere: put the terminal into raw mode just long enough to read
    # a single character, then restore the previous settings.
    import tty, termios

    def getchar():
        """Read one character from stdin without waiting for Enter."""
        stdin_fd = sys.stdin.fileno()
        saved_attrs = termios.tcgetattr(stdin_fd)
        try:
            tty.setraw(stdin_fd)
            return sys.stdin.read(1)
        finally:
            termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
else:
    def getchar():
        """Read one keypress from the console without echoing it."""
        return msvcrt.getch()
157
158
class WebCase(TestCase):
    """TestCase with helpers for requesting pages and checking responses.

    getPage stores the latest response on self.status, self.headers and
    self.body (plus self.url, self.time and self.cookies); the assert*
    methods check those attributes. Failed checks go through
    _handlewebError, which raises self.failureException when
    self.interactive is False and otherwise prompts on the console.
    """
    HOST = "127.0.0.1"
    PORT = 8000
    # A connection class (a new connection per request) or, after
    # set_persistent(True), a single open connection instance.
    HTTP_CONN = httplib.HTTPConnection
    PROTOCOL = "HTTP/1.1"

    scheme = "http"
    url = None

    # Filled in by getPage from the latest response.
    status = None
    headers = None
    body = None
    time = None

    def get_conn(self, auto_open=False):
        """Return a new, already-connected connection to our HTTP server.

        Uses HTTPSConnection when self.scheme is "https"; auto_open is set
        on the connection to control automatic re-connection.
        """
        if self.scheme == "https":
            cls = httplib.HTTPSConnection
        else:
            cls = httplib.HTTPConnection
        conn = cls(self.interface(), self.PORT)
        # Automatically re-connect?
        conn.auto_open = auto_open
        conn.connect()
        return conn

    def set_persistent(self, on=True, auto_open=False):
        """Make our HTTP_CONN persistent (or not).

        If the 'on' argument is True (the default), then self.HTTP_CONN
        will be set to an instance of httplib.HTTPConnection (or HTTPS
        if self.scheme is "https"). This will then persist across requests.

        We only allow for a single open connection, so if you call this
        and we currently have an open connection, it will be closed.
        """
        try:
            self.HTTP_CONN.close()
        except (TypeError, AttributeError):
            # HTTP_CONN is a class, not an open connection: nothing to close.
            pass

        if on:
            self.HTTP_CONN = self.get_conn(auto_open=auto_open)
        else:
            if self.scheme == "https":
                self.HTTP_CONN = httplib.HTTPSConnection
            else:
                self.HTTP_CONN = httplib.HTTPConnection

    def _get_persistent(self):
        # openURL treats any object with a 'host' attribute as an open
        # connection instance and anything else as a connection class; use
        # the same test here. (Every new-style class has __class__, so the
        # old hasattr(..., "__class__") test was always true for them.)
        return hasattr(self.HTTP_CONN, "host")
    def _set_persistent(self, on):
        self.set_persistent(on)
    persistent = property(_get_persistent, _set_persistent)

    def interface(self):
        """Return an IP address for a client connection.

        If the server is listening on '0.0.0.0' (INADDR_ANY)
        or '::' (IN6ADDR_ANY), this will return the proper localhost."""
        return interface(self.HOST)

    def getPage(self, url, headers=None, method="GET", body=None, protocol=None):
        """Open the url with debugging support. Return status, headers, body.

        The response is also stored on self.status/headers/body, the elapsed
        seconds on self.time, and every Set-Cookie value as a ('Cookie',
        value) pair on self.cookies. Raises ServerError if server_error was
        triggered while the request was outstanding.
        """
        ServerError.on = False

        self.url = url
        self.time = None
        start = time.time()
        result = openURL(url, headers, method, body, self.HOST, self.PORT,
                         self.HTTP_CONN, protocol or self.PROTOCOL)
        self.time = time.time() - start
        self.status, self.headers, self.body = result

        # Build a list of request cookies from the previous response cookies.
        self.cookies = [('Cookie', v) for k, v in self.headers
                        if k.lower() == 'set-cookie']

        if ServerError.on:
            raise ServerError()
        return result

    # When False, failed assertions raise immediately instead of prompting.
    interactive = True
    # Lines of body shown per page before the "More" prompt.
    console_height = 30

    def _handlewebError(self, msg):
        """Report a failed check, optionally letting the user inspect it.

        Non-interactive: raise self.failureException(msg). Interactive: read
        single keypresses via getchar() to show the body, headers, status or
        URL, until the user chooses to ignore (return None), raise, or exit.
        """
        print("")
        print("    ERROR: %s" % msg)

        if not self.interactive:
            raise self.failureException(msg)

        p = "    Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
        sys.stdout.write(p)
        # The prompt has no trailing newline, so flush it explicitly.
        sys.stdout.flush()
        while True:
            i = getchar().upper()
            if not i:
                # stdin is exhausted: no answer can ever arrive, so fail
                # rather than spin forever ('' is "in" every string).
                print("")
                raise self.failureException(msg)
            if i not in "BHSUIRX":
                continue
            print(i)  # Also prints new line
            if i == "B":
                for x, line in enumerate(self.body.splitlines()):
                    if (x + 1) % self.console_height == 0:
                        # The \r (and no newline) lets the next line
                        # overwrite this prompt.
                        sys.stdout.write("<-- More -->\r")
                        sys.stdout.flush()
                        m = getchar().lower()
                        # Erase our "More" prompt
                        sys.stdout.write("            \r")
                        if m == "q":
                            break
                    print(line)
            elif i == "H":
                pprint.pprint(self.headers)
            elif i == "S":
                print(self.status)
            elif i == "U":
                print(self.url)
            elif i == "I":
                # return without raising the normal exception
                return
            elif i == "R":
                raise self.failureException(msg)
            elif i == "X":
                self.exit()
            sys.stdout.write(p)
            sys.stdout.flush()

    def exit(self):
        sys.exit()

    def __call__(self, result=None):
        """Run this test, reporting into result (a default one if None).

        Mirrors TestCase's own run loop, but KeyboardInterrupt and
        SystemExit are re-raised instead of being recorded as errors, so
        choosing sys.e[X]it in _handlewebError really stops the run.
        """
        if result is None:
            result = self.defaultTestResult()
        result.startTest(self)
        # TestCase's private attribute names changed in Python 2.5, and
        # Python 2.7 dropped _exc_info entirely (sys.exc_info is used there).
        if sys.version_info >= (2, 5):
            testMethod = getattr(self, self._testMethodName)
            exc_info = getattr(self, "_exc_info", sys.exc_info)
        else:
            testMethod = getattr(self, self._TestCase__testMethodName)
            exc_info = self._TestCase__exc_info
        try:
            # Bare excepts below deliberately match unittest's own runner,
            # which records any exception (even non-Exception ones).
            try:
                self.setUp()
            except (KeyboardInterrupt, SystemExit):
                raise
            except:
                result.addError(self, exc_info())
                return

            ok = False
            try:
                testMethod()
                ok = True
            except self.failureException:
                result.addFailure(self, exc_info())
            except (KeyboardInterrupt, SystemExit):
                raise
            except:
                result.addError(self, exc_info())

            try:
                self.tearDown()
            except (KeyboardInterrupt, SystemExit):
                raise
            except:
                result.addError(self, exc_info())
                ok = False
            if ok:
                result.addSuccess(self)
        finally:
            result.stopTest(self)

    def assertStatus(self, status, msg=None):
        """Fail if self.status != status.

        status may be a full status string ("200 OK"), an int code (200),
        or a list/tuple of either, in which case any match passes.
        """
        if isinstance(status, basestring):
            if not self.status == status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        elif isinstance(status, int):
            code = int(self.status[:3])
            if code != status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        else:
            # status is a tuple or list.
            match = False
            for s in status:
                if isinstance(s, basestring):
                    if self.status == s:
                        match = True
                        break
                elif int(self.status[:3]) == s:
                    match = True
                    break
            if not match:
                if msg is None:
                    msg = 'Status (%r) not in %r' % (self.status, status)
                self._handlewebError(msg)

    def assertHeader(self, key, value=None, msg=None):
        """Fail if (key, [value]) not in self.headers.

        The name is compared case-insensitively; value (if given) is
        compared as str(value). Returns the matching header value, or None
        if the failure was ignored interactively.
        """
        lowkey = key.lower()
        for k, v in self.headers:
            if k.lower() == lowkey:
                if value is None or str(value) == v:
                    return v

        if msg is None:
            if value is None:
                msg = '%r not in headers' % key
            else:
                msg = '%r:%r not in headers' % (key, value)
        self._handlewebError(msg)

    def assertHeaderItemValue(self, key, value, msg=None):
        """Fail if the header does not contain the specified value.

        The header value is split on commas and each item stripped, so this
        checks membership in a comma-separated list. Returns value on
        success.
        """
        actual_value = self.assertHeader(key, msg=msg)
        if actual_value is None:
            # The header is missing and the user chose to ignore that
            # failure; there is nothing left to check.
            return None
        # A comprehension (unlike map(str.strip, ...)) also handles unicode.
        header_values = [item.strip() for item in actual_value.split(',')]
        if value in header_values:
            return value

        if msg is None:
            msg = "%r not in %r" % (value, header_values)
        self._handlewebError(msg)

    def assertNoHeader(self, key, msg=None):
        """Fail if key in self.headers (names compared case-insensitively)."""
        lowkey = key.lower()
        matches = [k for k, v in self.headers if k.lower() == lowkey]
        if matches:
            if msg is None:
                msg = '%r in headers' % key
            self._handlewebError(msg)

    def assertBody(self, value, msg=None):
        """Fail if value != self.body."""
        if value != self.body:
            if msg is None:
                msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body)
            self._handlewebError(msg)

    def assertInBody(self, value, msg=None):
        """Fail if value not in self.body."""
        if value not in self.body:
            if msg is None:
                msg = '%r not in body: %s' % (value, self.body)
            self._handlewebError(msg)

    def assertNotInBody(self, value, msg=None):
        """Fail if value in self.body."""
        if value in self.body:
            if msg is None:
                msg = '%r found in body' % value
            self._handlewebError(msg)

    def assertMatchesBody(self, pattern, msg=None, flags=0):
        """Fail if value (a regex pattern) is not in self.body."""
        if re.search(pattern, self.body, flags) is None:
            if msg is None:
                msg = 'No match for %r in body' % pattern
            self._handlewebError(msg)
457
458
# Request methods that get default Content-Type/Content-Length headers.
methods_with_bodies = ("POST", "PUT")


def cleanHeaders(headers, method, body, host, port):
    """Return request headers, with required headers added (if missing).

    headers is a sequence of (name, value) pairs, or None. A new list is
    returned; the caller's sequence is never modified, so reusing one list
    across requests does not accumulate stale defaults.

    - A Host header for host:port is added unless one is present. The port
      is omitted when it is 80, and bare IPv6 literals are bracketed.
    - For methods in methods_with_bodies (case-insensitive), a default
      Content-Type of application/x-www-form-urlencoded is added unless one
      is present. A Content-Length for body is added unless Content-Length
      or Transfer-Encoding is already present.
    """
    if headers is None:
        headers = []
    else:
        headers = list(headers)

    # Lower-cased names of the headers the caller supplied.
    present = {}
    for k, v in headers:
        present[k.lower()] = True

    # Add the required Host request header if not present.
    # [This specifies the host:port of the server, not the client.]
    if 'host' not in present:
        if ':' in host and not host.startswith('['):
            # Bare IPv6 literal: it must be bracketed in a Host header,
            # otherwise its colons are ambiguous with the port separator.
            host = '[%s]' % host
        if port == 80:
            headers.append(("Host", host))
        else:
            headers.append(("Host", "%s:%s" % (host, port)))

    # openURL sends method.upper(), so compare the same way here.
    if method.upper() in methods_with_bodies:
        # Stick in default type and length headers if not present.
        if 'content-type' not in present:
            headers.append(("Content-Type", "application/x-www-form-urlencoded"))
        if 'content-length' not in present and 'transfer-encoding' not in present:
            headers.append(("Content-Length", str(len(body or ""))))

    return headers
491
492
def shb(response):
    """Return status, headers, body the way we like from a response.

    status is "<status> <reason>"; headers is a list of (name, value) pairs
    in the order received, parsed from the raw lines in
    response.msg.headers so that repeated headers (e.g. several Set-Cookie
    lines) stay separate entries. Headers with empty values are kept, and
    folded continuation lines are joined to the previous value with a single
    space. body is whatever response.read() returns.
    """
    h = []
    key, value = None, None
    for line in response.msg.headers:
        if not line:
            continue
        if line[0] in " \t":
            # Continuation of the previous header: the fold counts as one
            # space. Ignore a stray continuation with nothing to continue.
            if key is not None:
                value = ("%s %s" % (value, line.strip())).strip()
            continue
        if key:
            h.append((key, value))
        key, value = line.split(":", 1)
        key = key.strip()
        value = value.strip()
    if key:
        h.append((key, value))

    return "%s %s" % (response.status, response.reason), h, response.read()
511
512
def openURL(url, headers=None, method="GET", body=None,
            host="127.0.0.1", port=8000, http_conn=httplib.HTTPConnection,
            protocol="HTTP/1.1"):
    """Open the given HTTP resource and return status, headers, and body.

    http_conn may be a connection class, in which case a fresh connection to
    interface(host):port is made for each attempt and closed afterwards. It
    may also be an open connection instance (anything with a 'host'
    attribute), which is used as-is and left open. On socket.error the
    request is retried up to 10 times, 0.5s apart; the last error is
    re-raised.

    Returns the (status, headers, body) triple produced by shb().
    """
    headers = cleanHeaders(headers, method, body, host, port)

    # Allow http_conn to be a class or an instance; we only create (and so
    # only close) the connection when we were handed a class.
    owns_conn = not hasattr(http_conn, "host")

    # Retrying is simply in case of socket errors.
    # Normal case--it should run once.
    attempts = 10
    for trial in range(attempts):
        conn = None
        try:
            try:
                if owns_conn:
                    conn = http_conn(interface(host), port)
                else:
                    conn = http_conn

                conn._http_vsn_str = protocol
                conn._http_vsn = int("".join([x for x in protocol if x.isdigit()]))

                # skip_accept_encoding argument added in python version 2.4
                if sys.version_info < (2, 4):
                    def putheader(self, header, value):
                        if header == 'Accept-Encoding' and value == 'identity':
                            return
                        self.__class__.putheader(self, header, value)
                    import new
                    conn.putheader = new.instancemethod(putheader, conn, conn.__class__)
                    conn.putrequest(method.upper(), url, skip_host=True)
                else:
                    conn.putrequest(method.upper(), url, skip_host=True,
                                    skip_accept_encoding=True)

                for key, value in headers:
                    conn.putheader(key, value)
                conn.endheaders()

                if body is not None:
                    conn.send(body)

                # Handle response (shb reads the whole body before we close).
                response = conn.getresponse()
                s, h, b = shb(response)
                return s, h, b
            except socket.error:
                # Re-raise from inside the except clause: a bare raise
                # after the loop only works by accident on Python 2.
                if trial == attempts - 1:
                    raise
                time.sleep(0.5)
        finally:
            if owns_conn and conn is not None:
                # We made this connection, so close it -- also after a
                # failed attempt, so retries don't leak sockets.
                conn.close()
566
567
# Exception classes your web framework handles normally, which server_error
# should therefore leave alone. They are matched by exact class; subclasses
# are not matched.
ignored_exceptions = []

# Set this to True when you can't guarantee that each response will
# immediately follow its request -- for example, when requests are
# handled by multiple threads.
ignore_all = False


class ServerError(Exception):
    # Class-level flag: server_error sets it to True; WebCase.getPage resets
    # it before each request and raises ServerError if it was set.
    on = False


def server_error(exc=None):
    """Server debug hook. Return True if exception handled, False if ignored.

    exc is an exc_info-style tuple; when omitted, sys.exc_info() (the
    exception currently being handled) is used. Unless ignore_all is set or
    the exception's type is listed in ignored_exceptions, its traceback is
    printed to stdout and ServerError.on is set.

    You probably want to wrap this, so you can still handle an error using
    your framework when it's ignored.
    """
    info = exc
    if info is None:
        info = sys.exc_info()

    if not ignore_all and info[0] not in ignored_exceptions:
        ServerError.on = True
        print("")
        print("".join(traceback.format_exception(*info)))
        return True
    return False
597
Note: See TracBrowser for help on using the browser.

Hosted by WebFaction

Log in as guest/cpguest to create tickets