From 9da047b3a5d1a5421ee829aab1e9bc57427826c9 Mon Sep 17 00:00:00 2001 From: Senthil Kumaran Date: Mon, 14 Apr 2014 13:07:56 -0400 Subject: Issue #7776: Fix ``Host:'' header and reconnection when using http.client.HTTPConnection.set_tunnel(). Patch by Nikolaus Rath. --- Lib/http/client.py | 73 ++++++++++++++++++++++++++++++++---------------- Lib/test/test_httplib.py | 50 +++++++++++++++++++++++++++++++-- Misc/NEWS | 3 ++ 3 files changed, 100 insertions(+), 26 deletions(-) diff --git a/Lib/http/client.py b/Lib/http/client.py index 12c1a5f..d2013f2 100644 --- a/Lib/http/client.py +++ b/Lib/http/client.py @@ -747,14 +747,30 @@ class HTTPConnection: self._tunnel_port = None self._tunnel_headers = {} - self._set_hostport(host, port) + (self.host, self.port) = self._get_hostport(host, port) + + # This is stored as an instance variable to allow unit + # tests to replace it with a suitable mockup + self._create_connection = socket.create_connection def set_tunnel(self, host, port=None, headers=None): - """ Sets up the host and the port for the HTTP CONNECT Tunnelling. + """Set up host and port for HTTP CONNECT tunnelling. + + In a connection that uses HTTP CONNECT tunneling, the host passed to the + constructor is used as a proxy server that relays all communication to + the endpoint passed to `set_tunnel`. This done by sending an HTTP + CONNECT request to the proxy server when the connection is established. - The headers argument should be a mapping of extra HTTP headers - to send with the CONNECT request. + This method must be called before the HTML connection has been + established. + + The headers argument should be a mapping of extra HTTP headers to send + with the CONNECT request. """ + + if self.sock: + raise RuntimeError("Can't set up tunnel for established connection") + self._tunnel_host = host self._tunnel_port = port if headers: @@ -762,7 +778,7 @@ class HTTPConnection: else: self._tunnel_headers.clear() - def _set_hostport(self, host, port): + def _get_hostport(self, host, port): if port is None: i = host.rfind(':') j = host.rfind(']') # ipv6 addresses have [...] @@ -779,15 +795,16 @@ class HTTPConnection: port = self.default_port if host and host[0] == '[' and host[-1] == ']': host = host[1:-1] - self.host = host - self.port = port + + return (host, port) def set_debuglevel(self, level): self.debuglevel = level def _tunnel(self): - self._set_hostport(self._tunnel_host, self._tunnel_port) - connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port) + (host, port) = self._get_hostport(self._tunnel_host, + self._tunnel_port) + connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port) connect_bytes = connect_str.encode("ascii") self.send(connect_bytes) for header, value in self._tunnel_headers.items(): @@ -815,8 +832,9 @@ class HTTPConnection: def connect(self): """Connect to the host and port specified in __init__.""" - self.sock = socket.create_connection((self.host,self.port), - self.timeout, self.source_address) + self.sock = self._create_connection((self.host,self.port), + self.timeout, self.source_address) + if self._tunnel_host: self._tunnel() @@ -985,22 +1003,29 @@ class HTTPConnection: netloc_enc = netloc.encode("idna") self.putheader('Host', netloc_enc) else: + if self._tunnel_host: + host = self._tunnel_host + port = self._tunnel_port + else: + host = self.host + port = self.port + try: - host_enc = self.host.encode("ascii") + host_enc = host.encode("ascii") except UnicodeEncodeError: - host_enc = self.host.encode("idna") + host_enc = host.encode("idna") # As per RFC 273, IPv6 address should be wrapped with [] # when used as Host header - if self.host.find(':') >= 0: + if host.find(':') >= 0: host_enc = b'[' + host_enc + b']' - if self.port == self.default_port: + if port == self.default_port: self.putheader('Host', host_enc) else: host_enc = host_enc.decode("ascii") - self.putheader('Host', "%s:%s" % (host_enc, self.port)) + self.putheader('Host', "%s:%s" % (host_enc, port)) # note: we are assuming that clients will not attempt to set these # headers since *this* library must deal with the @@ -1193,19 +1218,19 @@ else: def connect(self): "Connect to a host on a given (SSL) port." - sock = socket.create_connection((self.host, self.port), - self.timeout, self.source_address) + super().connect() if self._tunnel_host: - self.sock = sock - self._tunnel() + server_hostname = self._tunnel_host + else: + server_hostname = self.host + sni_hostname = server_hostname if ssl.HAS_SNI else None - server_hostname = self.host if ssl.HAS_SNI else None - self.sock = self._context.wrap_socket(sock, - server_hostname=server_hostname) + self.sock = self._context.wrap_socket(self.sock, + server_hostname=sni_hostname) if not self._context.check_hostname and self._check_hostname: try: - ssl.match_hostname(self.sock.getpeercert(), self.host) + ssl.match_hostname(self.sock.getpeercert(), server_hostname) except Exception: self.sock.shutdown(socket.SHUT_RDWR) self.sock.close() diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 30b6c0c..22f7329 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -21,13 +21,15 @@ CACERT_svn_python_org = os.path.join(here, 'https_svn_python_org_root.pem') HOST = support.HOST class FakeSocket: - def __init__(self, text, fileclass=io.BytesIO): + def __init__(self, text, fileclass=io.BytesIO, host=None, port=None): if isinstance(text, str): text = text.encode("ascii") self.text = text self.fileclass = fileclass self.data = b'' self.sendall_calls = 0 + self.host = host + self.port = port def sendall(self, data): self.sendall_calls += 1 @@ -38,6 +40,9 @@ class FakeSocket: raise client.UnimplementedFileMode() return self.fileclass(self.text) + def close(self): + pass + class EPipeSocket(FakeSocket): def __init__(self, text, pipe_trigger): @@ -970,10 +975,51 @@ class HTTPResponseTest(TestCase): header = self.resp.getheader('No-Such-Header',default=42) self.assertEqual(header, 42) +class TunnelTests(TestCase): + + def test_connect(self): + response_text = ( + 'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT + 'HTTP/1.1 200 OK\r\n' # Reply to HEAD + 'Content-Length: 42\r\n\r\n' + ) + + def create_connection(address, timeout=None, source_address=None): + return FakeSocket(response_text, host=address[0], + port=address[1]) + + conn = client.HTTPConnection('proxy.com') + conn._create_connection = create_connection + + # Once connected, we shouldn't be able to tunnel anymore + conn.connect() + self.assertRaises(RuntimeError, conn.set_tunnel, + 'destination.com') + + # But if we close the connection, we're good + conn.close() + conn.set_tunnel('destination.com') + conn.request('HEAD', '/', '') + + self.assertEqual(conn.sock.host, 'proxy.com') + self.assertEqual(conn.sock.port, 80) + self.assertTrue(b'CONNECT destination.com' in conn.sock.data) + self.assertTrue(b'Host: destination.com' in conn.sock.data) + + # This test should be removed when CONNECT gets the HTTP/1.1 blessing + self.assertTrue(b'Host: proxy.com' not in conn.sock.data) + + conn.close() + conn.request('PUT', '/', '') + self.assertEqual(conn.sock.host, 'proxy.com') + self.assertEqual(conn.sock.port, 80) + self.assertTrue(b'CONNECT destination.com' in conn.sock.data) + self.assertTrue(b'Host: destination.com' in conn.sock.data) + def test_main(verbose=None): support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, HTTPSTest, RequestBodyTest, SourceAddressTest, - HTTPResponseTest) + HTTPResponseTest, TunnelTests) if __name__ == '__main__': test_main() diff --git a/Misc/NEWS b/Misc/NEWS index 240c6cc..9629b5e 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -33,6 +33,9 @@ Core and Builtins Library ------- +- Issue #7776: Fix ``Host:'' header and reconnection when using + http.client.HTTPConnection.set_tunnel(). Patch by Nikolaus Rath. + - Issue #20968: unittest.mock.MagicMock now supports division. Patch by Johannes Baiter. -- cgit v0.12