diff options
author | Yury Selivanov <yury@magic.io> | 2017-12-19 01:02:54 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-12-19 01:02:54 (GMT) |
commit | 9818142b1bd20243733a953fb8aa2c7be314c47c (patch) | |
tree | 625350fae6c199ae5442118eaf36db480fe00046 /Lib | |
parent | 6efcb6d3d5911aaf699f9df3bb3bc26e94f38e6d (diff) | |
download | cpython-9818142b1bd20243733a953fb8aa2c7be314c47c.zip cpython-9818142b1bd20243733a953fb8aa2c7be314c47c.tar.gz cpython-9818142b1bd20243733a953fb8aa2c7be314c47c.tar.bz2 |
bpo-32331: Fix socket.type when SOCK_NONBLOCK is available (#4877)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/socket.py | 6 | ||||
-rw-r--r-- | Lib/test/test_asyncore.py | 10 | ||||
-rw-r--r-- | Lib/test/test_socket.py | 41 |
3 files changed, 39 insertions, 18 deletions
diff --git a/Lib/socket.py b/Lib/socket.py index 1ada24d..2d8aee3 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -203,11 +203,7 @@ class socket(_socket.socket): For IP sockets, the address info is a pair (hostaddr, port). """ fd, addr = self._accept() - # If our type has the SOCK_NONBLOCK flag, we shouldn't pass it onto the - # new socket. We do not currently allow passing SOCK_NONBLOCK to - # accept4, so the returned socket is always blocking. - type = self.type & ~globals().get("SOCK_NONBLOCK", 0) - sock = socket(self.family, type, self.proto, fileno=fd) + sock = socket(self.family, self.type, self.proto, fileno=fd) # Issue #7995: if no default timeout is set and the listening # socket had a (non-zero) timeout, force the new socket in blocking # mode to override platform-specific socket flags inheritance. diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py index ee0c3b3..694ddff 100644 --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -726,14 +726,10 @@ class BaseTestAPI: def test_create_socket(self): s = asyncore.dispatcher() s.create_socket(self.family) + self.assertEqual(s.socket.type, socket.SOCK_STREAM) self.assertEqual(s.socket.family, self.family) - SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0) - sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK - if hasattr(socket, 'SOCK_CLOEXEC'): - self.assertIn(s.socket.type, - (sock_type | socket.SOCK_CLOEXEC, sock_type)) - else: - self.assertEqual(s.socket.type, sock_type) + self.assertEqual(s.socket.gettimeout(), 0) + self.assertFalse(s.socket.get_inheritable()) def test_bind(self): if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 5b4c5f9..43688ea 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1577,6 +1577,22 @@ class GeneralModuleTests(unittest.TestCase): self.assertEqual(str(s.family), 'AddressFamily.AF_INET') self.assertEqual(str(s.type), 'SocketKind.SOCK_STREAM') + def test_socket_consistent_sock_type(self): + SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0) + SOCK_CLOEXEC = getattr(socket, 'SOCK_CLOEXEC', 0) + sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC + + with socket.socket(socket.AF_INET, sock_type) as s: + self.assertEqual(s.type, socket.SOCK_STREAM) + s.settimeout(1) + self.assertEqual(s.type, socket.SOCK_STREAM) + s.settimeout(0) + self.assertEqual(s.type, socket.SOCK_STREAM) + s.setblocking(True) + self.assertEqual(s.type, socket.SOCK_STREAM) + s.setblocking(False) + self.assertEqual(s.type, socket.SOCK_STREAM) + @unittest.skipIf(os.name == 'nt', 'Will not work on Windows') def test_uknown_socket_family_repr(self): # Test that when created with a family that's not one of the known @@ -1589,9 +1605,18 @@ class GeneralModuleTests(unittest.TestCase): # On Windows this trick won't work, so the test is skipped. fd, path = tempfile.mkstemp() self.addCleanup(os.unlink, path) - with socket.socket(family=42424, type=13331, fileno=fd) as s: - self.assertEqual(s.family, 42424) - self.assertEqual(s.type, 13331) + unknown_family = max(socket.AddressFamily.__members__.values()) + 1 + + unknown_type = max( + kind + for name, kind in socket.SocketKind.__members__.items() + if name not in {'SOCK_NONBLOCK', 'SOCK_CLOEXEC'} + ) + 1 + + with socket.socket( + family=unknown_family, type=unknown_type, fileno=fd) as s: + self.assertEqual(s.family, unknown_family) + self.assertEqual(s.type, unknown_type) @unittest.skipUnless(hasattr(os, 'sendfile'), 'test needs os.sendfile()') def test__sendfile_use_sendfile(self): @@ -5084,7 +5109,7 @@ class InheritanceTest(unittest.TestCase): def test_SOCK_CLOEXEC(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: - self.assertTrue(s.type & socket.SOCK_CLOEXEC) + self.assertEqual(s.type, socket.SOCK_STREAM) self.assertFalse(s.get_inheritable()) def test_default_inheritable(self): @@ -5149,11 +5174,15 @@ class InheritanceTest(unittest.TestCase): class NonblockConstantTest(unittest.TestCase): def checkNonblock(self, s, nonblock=True, timeout=0.0): if nonblock: - self.assertTrue(s.type & socket.SOCK_NONBLOCK) + self.assertEqual(s.type, socket.SOCK_STREAM) self.assertEqual(s.gettimeout(), timeout) + self.assertTrue( + fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK) else: - self.assertFalse(s.type & socket.SOCK_NONBLOCK) + self.assertEqual(s.type, socket.SOCK_STREAM) self.assertEqual(s.gettimeout(), None) + self.assertFalse( + fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK) @support.requires_linux_version(2, 6, 28) def test_SOCK_NONBLOCK(self): |