summaryrefslogtreecommitdiffstats
path: root/Lib/socket.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/socket.py')
-rw-r--r--Lib/socket.py187
1 files changed, 145 insertions, 42 deletions
diff --git a/Lib/socket.py b/Lib/socket.py
index 7f5a91e..e4f0a81 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -16,7 +16,7 @@ gethostname() -- return the current hostname
gethostbyname() -- map a hostname to its IP number
gethostbyaddr() -- map an IP number or hostname to DNS info
getservbyname() -- map a service name and a protocol name to a port number
-getprotobyname() -- mape a protocol name (e.g. 'tcp') to a number
+getprotobyname() -- map a protocol name (e.g. 'tcp') to a number
ntohs(), ntohl() -- convert 16, 32 bit int from network to host byte order
htons(), htonl() -- convert 16, 32 bit int from host to network byte order
inet_aton() -- convert IP addr string (123.45.67.89) to 32-bit packed format
@@ -24,6 +24,7 @@ inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89)
ssl() -- secure socket layer support (only available if configured)
socket.getdefaulttimeout() -- get the default timeout value
socket.setdefaulttimeout() -- set the default timeout value
+create_connection() -- connects to an address, with an optional timeout
[*] not available on all platforms!
@@ -45,15 +46,37 @@ the setsockopt() and getsockopt() methods.
import _socket
from _socket import *
-_have_ssl = False
try:
import _ssl
- from _ssl import *
- _have_ssl = True
except ImportError:
+ # no SSL support
pass
-
-import os, sys
+else:
+ def ssl(sock, keyfile=None, certfile=None):
+ # we do an internal import here because the ssl
+ # module imports the socket module
+ import ssl as _realssl
+ warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.",
+ DeprecationWarning, stacklevel=2)
+ return _realssl.sslwrap_simple(sock, keyfile, certfile)
+
+ # we need to import the same constants we used to...
+ from _ssl import SSLError as sslerror
+ from _ssl import \
+ RAND_add, \
+ RAND_egd, \
+ RAND_status, \
+ SSL_ERROR_ZERO_RETURN, \
+ SSL_ERROR_WANT_READ, \
+ SSL_ERROR_WANT_WRITE, \
+ SSL_ERROR_WANT_X509_LOOKUP, \
+ SSL_ERROR_SYSCALL, \
+ SSL_ERROR_SSL, \
+ SSL_ERROR_WANT_CONNECT, \
+ SSL_ERROR_EOF, \
+ SSL_ERROR_INVALID_ERROR_CODE
+
+import os, sys, warnings
try:
from cStringIO import StringIO
@@ -61,22 +84,17 @@ except ImportError:
from StringIO import StringIO
try:
- from errno import EBADF
+ import errno
except ImportError:
- EBADF = 9
+ errno = None
+EBADF = getattr(errno, 'EBADF', 9)
+EINTR = getattr(errno, 'EINTR', 4)
-__all__ = ["getfqdn"]
+__all__ = ["getfqdn", "create_connection"]
__all__.extend(os._get_exports_list(_socket))
-if _have_ssl:
- __all__.extend(os._get_exports_list(_ssl))
+
_realsocket = socket
-if _have_ssl:
- _realssl = ssl
- def ssl(sock, keyfile=None, certfile=None):
- if hasattr(sock, "_sock"):
- sock = sock._sock
- return _realssl(sock, keyfile, certfile)
# WSA error codes
if sys.platform.lower().startswith("win"):
@@ -132,6 +150,9 @@ _socketmethods = (
'sendall', 'setblocking',
'settimeout', 'gettimeout', 'shutdown')
+if os.name == "nt":
+ _socketmethods = _socketmethods + ('ioctl',)
+
if sys.platform == "riscos":
_socketmethods = _socketmethods + ('sleeptaskw',)
@@ -148,6 +169,10 @@ class _closedsocket(object):
send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
__getattr__ = _dummy
+# Wrapper around platform socket objects. This implements
+# a platform-independent dup() functionality. The
+# implementation currently relies on reference counting
+# to close the underlying socket object.
class _socketobject(object):
__doc__ = _realsocket.__doc__
@@ -206,7 +231,7 @@ class _fileobject(object):
__slots__ = ["mode", "bufsize", "softspace",
# "closed" is a property, see below
- "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
+ "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", "_wbuf_len",
"_close"]
def __init__(self, sock, mode='rb', bufsize=-1, close=False):
@@ -232,6 +257,7 @@ class _fileobject(object):
# realloc()ed down much smaller than their original allocation.
self._rbuf = StringIO()
self._wbuf = [] # A list of strings
+ self._wbuf_len = 0
self._close = close
def _getclosed(self):
@@ -256,9 +282,26 @@ class _fileobject(object):
def flush(self):
if self._wbuf:
- buffer = "".join(self._wbuf)
+ data = "".join(self._wbuf)
self._wbuf = []
- self._sock.sendall(buffer)
+ self._wbuf_len = 0
+ buffer_size = max(self._rbufsize, self.default_bufsize)
+ data_size = len(data)
+ write_offset = 0
+ try:
+ while write_offset < data_size:
+ with warnings.catch_warnings():
+ if sys.py3kwarning:
+ warnings.filterwarnings("ignore", ".*buffer",
+ DeprecationWarning)
+ self._sock.sendall(buffer(data, write_offset, buffer_size))
+ write_offset += buffer_size
+ finally:
+ if write_offset < data_size:
+ remainder = data[write_offset:]
+ del data # explicit free
+ self._wbuf.append(remainder)
+ self._wbuf_len = len(remainder)
def fileno(self):
return self._sock.fileno()
@@ -268,24 +311,24 @@ class _fileobject(object):
if not data:
return
self._wbuf.append(data)
+ self._wbuf_len += len(data)
if (self._wbufsize == 0 or
self._wbufsize == 1 and '\n' in data or
- self._get_wbuf_len() >= self._wbufsize):
+ self._wbuf_len >= self._wbufsize):
self.flush()
def writelines(self, list):
# XXX We could do better here for very long lists
# XXX Should really reject non-string non-buffers
- self._wbuf.extend(filter(None, map(str, list)))
+ lines = filter(None, map(str, list))
+ self._wbuf_len += sum(map(len, lines))
+ self._wbuf.extend(lines)
if (self._wbufsize <= 1 or
- self._get_wbuf_len() >= self._wbufsize):
+ self._wbuf_len >= self._wbufsize):
self.flush()
def _get_wbuf_len(self):
- buf_len = 0
- for x in self._wbuf:
- buf_len += len(x)
- return buf_len
+ return self._wbuf_len
def read(self, size=-1):
# Use max, disallow tiny reads in a loop as they are very inefficient.
@@ -301,7 +344,12 @@ class _fileobject(object):
# Read until EOF
self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
while True:
- data = self._sock.recv(rbufsize)
+ try:
+ data = self._sock.recv(rbufsize)
+ except error, e:
+ if e.args[0] == EINTR:
+ continue
+ raise
if not data:
break
buf.write(data)
@@ -325,7 +373,12 @@ class _fileobject(object):
# than that. The returned data string is short lived
# as we copy it into a StringIO and free it. This avoids
# fragmentation issues on many platforms.
- data = self._sock.recv(left)
+ try:
+ data = self._sock.recv(left)
+ except error, e:
+ if e.args[0] == EINTR:
+ continue
+ raise
if not data:
break
n = len(data)
@@ -368,24 +421,38 @@ class _fileobject(object):
self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
data = None
recv = self._sock.recv
- while data != "\n":
- data = recv(1)
- if not data:
- break
- buffers.append(data)
+ while True:
+ try:
+ while data != "\n":
+ data = recv(1)
+ if not data:
+ break
+ buffers.append(data)
+ except error, e:
+ # The try..except to catch EINTR was moved outside the
+ # recv loop to avoid the per byte overhead.
+ if e.args[0] == EINTR:
+ continue
+ raise
+ break
return "".join(buffers)
buf.seek(0, 2) # seek end
self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
while True:
- data = self._sock.recv(self._rbufsize)
+ try:
+ data = self._sock.recv(self._rbufsize)
+ except error, e:
+ if e.args[0] == EINTR:
+ continue
+ raise
if not data:
break
nl = data.find('\n')
if nl >= 0:
nl += 1
- buf.write(buffer(data, 0, nl))
- self._rbuf.write(buffer(data, nl))
+ buf.write(data[:nl])
+ self._rbuf.write(data[nl:])
del data
break
buf.write(data)
@@ -402,7 +469,12 @@ class _fileobject(object):
return rv
self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
while True:
- data = self._sock.recv(self._rbufsize)
+ try:
+ data = self._sock.recv(self._rbufsize)
+ except error, e:
+ if e.args[0] == EINTR:
+ continue
+ raise
if not data:
break
left = size - buf_len
@@ -411,9 +483,9 @@ class _fileobject(object):
if nl >= 0:
nl += 1
# save the excess data to _rbuf
- self._rbuf.write(buffer(data, nl))
+ self._rbuf.write(data[nl:])
if buf_len:
- buf.write(buffer(data, 0, nl))
+ buf.write(data[:nl])
break
else:
# Shortcut. Avoid data copy through buf when returning
@@ -425,8 +497,8 @@ class _fileobject(object):
# returning exactly all of our first recv().
return data
if n >= left:
- buf.write(buffer(data, 0, left))
- self._rbuf.write(buffer(data, left))
+ buf.write(data[:left])
+ self._rbuf.write(data[left:])
break
buf.write(data)
buf_len += n
@@ -456,3 +528,34 @@ class _fileobject(object):
if not line:
raise StopIteration
return line
+
+_GLOBAL_DEFAULT_TIMEOUT = object()
+
+def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT):
+ """Connect to *address* and return the socket object.
+
+ Convenience function. Connect to *address* (a 2-tuple ``(host,
+ port)``) and return the socket object. Passing the optional
+ *timeout* parameter will set the timeout on the socket instance
+ before attempting to connect. If no *timeout* is supplied, the
+ global default timeout setting returned by :func:`getdefaulttimeout`
+ is used.
+ """
+
+ msg = "getaddrinfo returns an empty list"
+ host, port = address
+ for res in getaddrinfo(host, port, 0, SOCK_STREAM):
+ af, socktype, proto, canonname, sa = res
+ sock = None
+ try:
+ sock = socket(af, socktype, proto)
+ if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
+ sock.settimeout(timeout)
+ sock.connect(sa)
+ return sock
+
+ except error, msg:
+ if sock is not None:
+ sock.close()
+
+ raise error, msg