diff options
-rw-r--r-- | Lib/httplib.py | 30 | ||||
-rw-r--r-- | Lib/test/test_httplib.py | 45 |
2 files changed, 64 insertions, 11 deletions
diff --git a/Lib/httplib.py b/Lib/httplib.py index 5133c8d..c0d372f 100644 --- a/Lib/httplib.py +++ b/Lib/httplib.py @@ -344,6 +344,7 @@ class HTTPResponse: self.will_close = 1 def _check_close(self): + conn = self.msg.getheader('connection') if self.version == 11: # An HTTP/1.1 proxy is assumed to stay open unless # explicitly closed. @@ -352,13 +353,18 @@ class HTTPResponse: return True return False - # An HTTP/1.0 response with a Connection header is probably - # the result of a confused proxy. Ignore it. + # Some HTTP/1.0 implementations have support for persistent + # connections, using rules different than HTTP/1.1. # For older HTTP, Keep-Alive indiciates persistent connection. if self.msg.getheader('keep-alive'): return False + # At least Akamai returns a "Connection: Keep-Alive" header, + # which was supposed to be sent by the client. + if conn and "keep-alive" in conn.lower(): + return False + # Proxy-Connection is a netscape hack. pconn = self.msg.getheader('proxy-connection') if pconn and "keep-alive" in pconn.lower(): @@ -381,6 +387,8 @@ class HTTPResponse: # called, meaning self.isclosed() is meaningful. return self.fp is None + # XXX It would be nice to have readline and __iter__ for this, too. + def read(self, amt=None): if self.fp is None: return '' @@ -728,15 +736,17 @@ class HTTPConnection: self._send_request(method, url, body, headers) def _send_request(self, method, url, body, headers): - # If headers already contains a host header, then define the - # optional skip_host argument to putrequest(). The check is - # harder because field names are case insensitive. - if 'host' in [k.lower() for k in headers]: - self.putrequest(method, url, skip_host=1) - else: - self.putrequest(method, url) + # honour explicitly requested Host: and Accept-Encoding headers + header_names = dict.fromkeys([k.lower() for k in headers]) + skips = {} + if 'host' in header_names: + skips['skip_host'] = 1 + if 'accept-encoding' in header_names: + skips['skip_accept_encoding'] = 1 - if body: + self.putrequest(method, url, **skips) + + if body and ('content-length' not in header_names): self.putheader('Content-Length', str(len(body))) for hdr, value in headers.iteritems(): self.putheader(hdr, value) diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index c57793d..5f252bb 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -2,13 +2,18 @@ import httplib import StringIO import sys -from test.test_support import verify,verbose +from unittest import TestCase + +from test import test_support class FakeSocket: def __init__(self, text, fileclass=StringIO.StringIO): self.text = text self.fileclass = fileclass + def sendall(self, data): + self.data = data + def makefile(self, mode, bufsize=None): if mode != 'r' and mode != 'rb': raise httplib.UnimplementedFileMode() @@ -32,6 +37,39 @@ class NoEOFStringIO(StringIO.StringIO): raise AssertionError('caller tried to read past EOF') return data + +class HeaderTests(TestCase): + def test_auto_headers(self): + # Some headers are added automatically, but should not be added by + # .request() if they are explicitly set. + + import httplib + + class HeaderCountingBuffer(list): + def __init__(self): + self.count = {} + def append(self, item): + kv = item.split(':') + if len(kv) > 1: + # item is a 'Key: Value' header string + lcKey = kv[0].lower() + self.count.setdefault(lcKey, 0) + self.count[lcKey] += 1 + list.append(self, item) + + for explicit_header in True, False: + for header in 'Content-length', 'Host', 'Accept-encoding': + conn = httplib.HTTPConnection('example.com') + conn.sock = FakeSocket('blahblahblah') + conn._buffer = HeaderCountingBuffer() + + body = 'spamspamspam' + headers = {} + if explicit_header: + headers[header] = str(len(body)) + conn.request('POST', '/', body, headers) + self.assertEqual(conn._buffer.count[header.lower()], 1) + # Collect output to a buffer so that we don't have to cope with line-ending # issues across platforms. Specifically, the headers will have \r\n pairs # and some platforms will strip them from the output file. @@ -110,4 +148,9 @@ def _test(): raise AssertionError, "Did not expect response from HEAD request" resp.close() + +def test_main(verbose=None): + tests = [HeaderTests,] + test_support.run_unittest(*tests) + test() |