diff options
Diffstat (limited to 'Lib/socket.py')
| -rw-r--r-- | Lib/socket.py | 187 |
1 files changed, 145 insertions, 42 deletions
diff --git a/Lib/socket.py b/Lib/socket.py index 7f5a91e..e4f0a81 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -16,7 +16,7 @@ gethostname() -- return the current hostname gethostbyname() -- map a hostname to its IP number gethostbyaddr() -- map an IP number or hostname to DNS info getservbyname() -- map a service name and a protocol name to a port number -getprotobyname() -- mape a protocol name (e.g. 'tcp') to a number +getprotobyname() -- map a protocol name (e.g. 'tcp') to a number ntohs(), ntohl() -- convert 16, 32 bit int from network to host byte order htons(), htonl() -- convert 16, 32 bit int from host to network byte order inet_aton() -- convert IP addr string (123.45.67.89) to 32-bit packed format @@ -24,6 +24,7 @@ inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89) ssl() -- secure socket layer support (only available if configured) socket.getdefaulttimeout() -- get the default timeout value socket.setdefaulttimeout() -- set the default timeout value +create_connection() -- connects to an address, with an optional timeout [*] not available on all platforms! @@ -45,15 +46,37 @@ the setsockopt() and getsockopt() methods. import _socket from _socket import * -_have_ssl = False try: import _ssl - from _ssl import * - _have_ssl = True except ImportError: + # no SSL support pass - -import os, sys +else: + def ssl(sock, keyfile=None, certfile=None): + # we do an internal import here because the ssl + # module imports the socket module + import ssl as _realssl + warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.", + DeprecationWarning, stacklevel=2) + return _realssl.sslwrap_simple(sock, keyfile, certfile) + + # we need to import the same constants we used to... + from _ssl import SSLError as sslerror + from _ssl import \ + RAND_add, \ + RAND_egd, \ + RAND_status, \ + SSL_ERROR_ZERO_RETURN, \ + SSL_ERROR_WANT_READ, \ + SSL_ERROR_WANT_WRITE, \ + SSL_ERROR_WANT_X509_LOOKUP, \ + SSL_ERROR_SYSCALL, \ + SSL_ERROR_SSL, \ + SSL_ERROR_WANT_CONNECT, \ + SSL_ERROR_EOF, \ + SSL_ERROR_INVALID_ERROR_CODE + +import os, sys, warnings try: from cStringIO import StringIO @@ -61,22 +84,17 @@ except ImportError: from StringIO import StringIO try: - from errno import EBADF + import errno except ImportError: - EBADF = 9 + errno = None +EBADF = getattr(errno, 'EBADF', 9) +EINTR = getattr(errno, 'EINTR', 4) -__all__ = ["getfqdn"] +__all__ = ["getfqdn", "create_connection"] __all__.extend(os._get_exports_list(_socket)) -if _have_ssl: - __all__.extend(os._get_exports_list(_ssl)) + _realsocket = socket -if _have_ssl: - _realssl = ssl - def ssl(sock, keyfile=None, certfile=None): - if hasattr(sock, "_sock"): - sock = sock._sock - return _realssl(sock, keyfile, certfile) # WSA error codes if sys.platform.lower().startswith("win"): @@ -132,6 +150,9 @@ _socketmethods = ( 'sendall', 'setblocking', 'settimeout', 'gettimeout', 'shutdown') +if os.name == "nt": + _socketmethods = _socketmethods + ('ioctl',) + if sys.platform == "riscos": _socketmethods = _socketmethods + ('sleeptaskw',) @@ -148,6 +169,10 @@ class _closedsocket(object): send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy __getattr__ = _dummy +# Wrapper around platform socket objects. This implements +# a platform-independent dup() functionality. The +# implementation currently relies on reference counting +# to close the underlying socket object. class _socketobject(object): __doc__ = _realsocket.__doc__ @@ -206,7 +231,7 @@ class _fileobject(object): __slots__ = ["mode", "bufsize", "softspace", # "closed" is a property, see below - "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", + "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", "_wbuf_len", "_close"] def __init__(self, sock, mode='rb', bufsize=-1, close=False): @@ -232,6 +257,7 @@ class _fileobject(object): # realloc()ed down much smaller than their original allocation. self._rbuf = StringIO() self._wbuf = [] # A list of strings + self._wbuf_len = 0 self._close = close def _getclosed(self): @@ -256,9 +282,26 @@ class _fileobject(object): def flush(self): if self._wbuf: - buffer = "".join(self._wbuf) + data = "".join(self._wbuf) self._wbuf = [] - self._sock.sendall(buffer) + self._wbuf_len = 0 + buffer_size = max(self._rbufsize, self.default_bufsize) + data_size = len(data) + write_offset = 0 + try: + while write_offset < data_size: + with warnings.catch_warnings(): + if sys.py3kwarning: + warnings.filterwarnings("ignore", ".*buffer", + DeprecationWarning) + self._sock.sendall(buffer(data, write_offset, buffer_size)) + write_offset += buffer_size + finally: + if write_offset < data_size: + remainder = data[write_offset:] + del data # explicit free + self._wbuf.append(remainder) + self._wbuf_len = len(remainder) def fileno(self): return self._sock.fileno() @@ -268,24 +311,24 @@ class _fileobject(object): if not data: return self._wbuf.append(data) + self._wbuf_len += len(data) if (self._wbufsize == 0 or self._wbufsize == 1 and '\n' in data or - self._get_wbuf_len() >= self._wbufsize): + self._wbuf_len >= self._wbufsize): self.flush() def writelines(self, list): # XXX We could do better here for very long lists # XXX Should really reject non-string non-buffers - self._wbuf.extend(filter(None, map(str, list))) + lines = filter(None, map(str, list)) + self._wbuf_len += sum(map(len, lines)) + self._wbuf.extend(lines) if (self._wbufsize <= 1 or - self._get_wbuf_len() >= self._wbufsize): + self._wbuf_len >= self._wbufsize): self.flush() def _get_wbuf_len(self): - buf_len = 0 - for x in self._wbuf: - buf_len += len(x) - return buf_len + return self._wbuf_len def read(self, size=-1): # Use max, disallow tiny reads in a loop as they are very inefficient. @@ -301,7 +344,12 @@ class _fileobject(object): # Read until EOF self._rbuf = StringIO() # reset _rbuf. we consume it via buf. while True: - data = self._sock.recv(rbufsize) + try: + data = self._sock.recv(rbufsize) + except error, e: + if e.args[0] == EINTR: + continue + raise if not data: break buf.write(data) @@ -325,7 +373,12 @@ class _fileobject(object): # than that. The returned data string is short lived # as we copy it into a StringIO and free it. This avoids # fragmentation issues on many platforms. - data = self._sock.recv(left) + try: + data = self._sock.recv(left) + except error, e: + if e.args[0] == EINTR: + continue + raise if not data: break n = len(data) @@ -368,24 +421,38 @@ class _fileobject(object): self._rbuf = StringIO() # reset _rbuf. we consume it via buf. data = None recv = self._sock.recv - while data != "\n": - data = recv(1) - if not data: - break - buffers.append(data) + while True: + try: + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + except error, e: + # The try..except to catch EINTR was moved outside the + # recv loop to avoid the per byte overhead. + if e.args[0] == EINTR: + continue + raise + break return "".join(buffers) buf.seek(0, 2) # seek end self._rbuf = StringIO() # reset _rbuf. we consume it via buf. while True: - data = self._sock.recv(self._rbufsize) + try: + data = self._sock.recv(self._rbufsize) + except error, e: + if e.args[0] == EINTR: + continue + raise if not data: break nl = data.find('\n') if nl >= 0: nl += 1 - buf.write(buffer(data, 0, nl)) - self._rbuf.write(buffer(data, nl)) + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) del data break buf.write(data) @@ -402,7 +469,12 @@ class _fileobject(object): return rv self._rbuf = StringIO() # reset _rbuf. we consume it via buf. while True: - data = self._sock.recv(self._rbufsize) + try: + data = self._sock.recv(self._rbufsize) + except error, e: + if e.args[0] == EINTR: + continue + raise if not data: break left = size - buf_len @@ -411,9 +483,9 @@ class _fileobject(object): if nl >= 0: nl += 1 # save the excess data to _rbuf - self._rbuf.write(buffer(data, nl)) + self._rbuf.write(data[nl:]) if buf_len: - buf.write(buffer(data, 0, nl)) + buf.write(data[:nl]) break else: # Shortcut. Avoid data copy through buf when returning @@ -425,8 +497,8 @@ class _fileobject(object): # returning exactly all of our first recv(). return data if n >= left: - buf.write(buffer(data, 0, left)) - self._rbuf.write(buffer(data, left)) + buf.write(data[:left]) + self._rbuf.write(data[left:]) break buf.write(data) buf_len += n @@ -456,3 +528,34 @@ class _fileobject(object): if not line: raise StopIteration return line + +_GLOBAL_DEFAULT_TIMEOUT = object() + +def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): + """Connect to *address* and return the socket object. + + Convenience function. Connect to *address* (a 2-tuple ``(host, + port)``) and return the socket object. Passing the optional + *timeout* parameter will set the timeout on the socket instance + before attempting to connect. If no *timeout* is supplied, the + global default timeout setting returned by :func:`getdefaulttimeout` + is used. + """ + + msg = "getaddrinfo returns an empty list" + host, port = address + for res in getaddrinfo(host, port, 0, SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket(af, socktype, proto) + if timeout is not _GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + sock.connect(sa) + return sock + + except error, msg: + if sock is not None: + sock.close() + + raise error, msg |
