summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2017-12-19 01:02:54 (GMT)
committerGitHub <noreply@github.com>2017-12-19 01:02:54 (GMT)
commit9818142b1bd20243733a953fb8aa2c7be314c47c (patch)
tree625350fae6c199ae5442118eaf36db480fe00046 /Lib
parent6efcb6d3d5911aaf699f9df3bb3bc26e94f38e6d (diff)
downloadcpython-9818142b1bd20243733a953fb8aa2c7be314c47c.zip
cpython-9818142b1bd20243733a953fb8aa2c7be314c47c.tar.gz
cpython-9818142b1bd20243733a953fb8aa2c7be314c47c.tar.bz2
bpo-32331: Fix socket.type when SOCK_NONBLOCK is available (#4877)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/socket.py6
-rw-r--r--Lib/test/test_asyncore.py10
-rw-r--r--Lib/test/test_socket.py41
3 files changed, 39 insertions, 18 deletions
diff --git a/Lib/socket.py b/Lib/socket.py
index 1ada24d..2d8aee3 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -203,11 +203,7 @@ class socket(_socket.socket):
For IP sockets, the address info is a pair (hostaddr, port).
"""
fd, addr = self._accept()
- # If our type has the SOCK_NONBLOCK flag, we shouldn't pass it onto the
- # new socket. We do not currently allow passing SOCK_NONBLOCK to
- # accept4, so the returned socket is always blocking.
- type = self.type & ~globals().get("SOCK_NONBLOCK", 0)
- sock = socket(self.family, type, self.proto, fileno=fd)
+ sock = socket(self.family, self.type, self.proto, fileno=fd)
# Issue #7995: if no default timeout is set and the listening
# socket had a (non-zero) timeout, force the new socket in blocking
# mode to override platform-specific socket flags inheritance.
diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py
index ee0c3b3..694ddff 100644
--- a/Lib/test/test_asyncore.py
+++ b/Lib/test/test_asyncore.py
@@ -726,14 +726,10 @@ class BaseTestAPI:
def test_create_socket(self):
s = asyncore.dispatcher()
s.create_socket(self.family)
+ self.assertEqual(s.socket.type, socket.SOCK_STREAM)
self.assertEqual(s.socket.family, self.family)
- SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
- sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK
- if hasattr(socket, 'SOCK_CLOEXEC'):
- self.assertIn(s.socket.type,
- (sock_type | socket.SOCK_CLOEXEC, sock_type))
- else:
- self.assertEqual(s.socket.type, sock_type)
+ self.assertEqual(s.socket.gettimeout(), 0)
+ self.assertFalse(s.socket.get_inheritable())
def test_bind(self):
if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 5b4c5f9..43688ea 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -1577,6 +1577,22 @@ class GeneralModuleTests(unittest.TestCase):
self.assertEqual(str(s.family), 'AddressFamily.AF_INET')
self.assertEqual(str(s.type), 'SocketKind.SOCK_STREAM')
+ def test_socket_consistent_sock_type(self):
+ SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
+ SOCK_CLOEXEC = getattr(socket, 'SOCK_CLOEXEC', 0)
+ sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC
+
+ with socket.socket(socket.AF_INET, sock_type) as s:
+ self.assertEqual(s.type, socket.SOCK_STREAM)
+ s.settimeout(1)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
+ s.settimeout(0)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
+ s.setblocking(True)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
+ s.setblocking(False)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
+
@unittest.skipIf(os.name == 'nt', 'Will not work on Windows')
def test_uknown_socket_family_repr(self):
# Test that when created with a family that's not one of the known
@@ -1589,9 +1605,18 @@ class GeneralModuleTests(unittest.TestCase):
# On Windows this trick won't work, so the test is skipped.
fd, path = tempfile.mkstemp()
self.addCleanup(os.unlink, path)
- with socket.socket(family=42424, type=13331, fileno=fd) as s:
- self.assertEqual(s.family, 42424)
- self.assertEqual(s.type, 13331)
+ unknown_family = max(socket.AddressFamily.__members__.values()) + 1
+
+ unknown_type = max(
+ kind
+ for name, kind in socket.SocketKind.__members__.items()
+ if name not in {'SOCK_NONBLOCK', 'SOCK_CLOEXEC'}
+ ) + 1
+
+ with socket.socket(
+ family=unknown_family, type=unknown_type, fileno=fd) as s:
+ self.assertEqual(s.family, unknown_family)
+ self.assertEqual(s.type, unknown_type)
@unittest.skipUnless(hasattr(os, 'sendfile'), 'test needs os.sendfile()')
def test__sendfile_use_sendfile(self):
@@ -5084,7 +5109,7 @@ class InheritanceTest(unittest.TestCase):
def test_SOCK_CLOEXEC(self):
with socket.socket(socket.AF_INET,
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s:
- self.assertTrue(s.type & socket.SOCK_CLOEXEC)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
self.assertFalse(s.get_inheritable())
def test_default_inheritable(self):
@@ -5149,11 +5174,15 @@ class InheritanceTest(unittest.TestCase):
class NonblockConstantTest(unittest.TestCase):
def checkNonblock(self, s, nonblock=True, timeout=0.0):
if nonblock:
- self.assertTrue(s.type & socket.SOCK_NONBLOCK)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
self.assertEqual(s.gettimeout(), timeout)
+ self.assertTrue(
+ fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
else:
- self.assertFalse(s.type & socket.SOCK_NONBLOCK)
+ self.assertEqual(s.type, socket.SOCK_STREAM)
self.assertEqual(s.gettimeout(), None)
+ self.assertFalse(
+ fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
@support.requires_linux_version(2, 6, 28)
def test_SOCK_NONBLOCK(self):