summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGregory P. Smith <greg@krypto.org>2012-07-16 06:42:26 (GMT)
committerGregory P. Smith <greg@krypto.org>2012-07-16 06:42:26 (GMT)
commitdad5711677f965a62713bfc9a34599c8f44572e8 (patch)
tree336b01e8761c7ff2baecefa3cc67f49a6e3a8245
parent4774946c3b7564b4a1c93da4ba4dba443a36a708 (diff)
downloadcpython-dad5711677f965a62713bfc9a34599c8f44572e8.zip
cpython-dad5711677f965a62713bfc9a34599c8f44572e8.tar.gz
cpython-dad5711677f965a62713bfc9a34599c8f44572e8.tar.bz2
Fixes Issue #14635: telnetlib will use poll() rather than select() when possible
to avoid failing due to the select() file descriptor limit.
-rw-r--r--Lib/telnetlib.py130
-rw-r--r--Lib/test/test_telnetlib.py96
-rw-r--r--Misc/ACKS1
-rw-r--r--Misc/NEWS3
4 files changed, 223 insertions, 7 deletions
diff --git a/Lib/telnetlib.py b/Lib/telnetlib.py
index 82b5e8f..a59693e 100644
--- a/Lib/telnetlib.py
+++ b/Lib/telnetlib.py
@@ -34,6 +34,7 @@ To do:
# Imported modules
+import errno
import sys
import socket
import select
@@ -205,6 +206,7 @@ class Telnet:
self.sb = 0 # flag for SB and SE sequence.
self.sbdataq = b''
self.option_callback = None
+ self._has_poll = hasattr(select, 'poll')
if host is not None:
self.open(host, port, timeout)
@@ -287,6 +289,61 @@ class Telnet:
is closed and no cooked data is available.
"""
+ if self._has_poll:
+ return self._read_until_with_poll(match, timeout)
+ else:
+ return self._read_until_with_select(match, timeout)
+
+ def _read_until_with_poll(self, match, timeout):
+ """Read until a given string is encountered or until timeout.
+
+ This method uses select.poll() to implement the timeout.
+ """
+ n = len(match)
+ call_timeout = timeout
+ if timeout is not None:
+ from time import time
+ time_start = time()
+ self.process_rawq()
+ i = self.cookedq.find(match)
+ if i < 0:
+ poller = select.poll()
+ poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
+ poller.register(self, poll_in_or_priority_flags)
+ while i < 0 and not self.eof:
+ try:
+ ready = poller.poll(call_timeout)
+ except select.error as e:
+ if e.errno == errno.EINTR:
+ if timeout is not None:
+ elapsed = time() - time_start
+ call_timeout = timeout-elapsed
+ continue
+ raise
+ for fd, mode in ready:
+ if mode & poll_in_or_priority_flags:
+ i = max(0, len(self.cookedq)-n)
+ self.fill_rawq()
+ self.process_rawq()
+ i = self.cookedq.find(match, i)
+ if timeout is not None:
+ elapsed = time() - time_start
+ if elapsed >= timeout:
+ break
+ call_timeout = timeout-elapsed
+ poller.unregister(self)
+ if i >= 0:
+ i = i + n
+ buf = self.cookedq[:i]
+ self.cookedq = self.cookedq[i:]
+ return buf
+ return self.read_very_lazy()
+
+ def _read_until_with_select(self, match, timeout=None):
+ """Read until a given string is encountered or until timeout.
+
+ The timeout is implemented using select.select().
+ """
n = len(match)
self.process_rawq()
i = self.cookedq.find(match)
@@ -589,6 +646,79 @@ class Telnet:
results are undeterministic, and may depend on the I/O timing.
"""
+ if self._has_poll:
+ return self._expect_with_poll(list, timeout)
+ else:
+ return self._expect_with_select(list, timeout)
+
+ def _expect_with_poll(self, expect_list, timeout=None):
+ """Read until one from a list of a regular expressions matches.
+
+ This method uses select.poll() to implement the timeout.
+ """
+ re = None
+ expect_list = expect_list[:]
+ indices = range(len(expect_list))
+ for i in indices:
+ if not hasattr(expect_list[i], "search"):
+ if not re: import re
+ expect_list[i] = re.compile(expect_list[i])
+ call_timeout = timeout
+ if timeout is not None:
+ from time import time
+ time_start = time()
+ self.process_rawq()
+ m = None
+ for i in indices:
+ m = expect_list[i].search(self.cookedq)
+ if m:
+ e = m.end()
+ text = self.cookedq[:e]
+ self.cookedq = self.cookedq[e:]
+ break
+ if not m:
+ poller = select.poll()
+ poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
+ poller.register(self, poll_in_or_priority_flags)
+ while not m and not self.eof:
+ try:
+ ready = poller.poll(call_timeout)
+ except select.error as e:
+ if e.errno == errno.EINTR:
+ if timeout is not None:
+ elapsed = time() - time_start
+ call_timeout = timeout-elapsed
+ continue
+ raise
+ for fd, mode in ready:
+ if mode & poll_in_or_priority_flags:
+ self.fill_rawq()
+ self.process_rawq()
+ for i in indices:
+ m = expect_list[i].search(self.cookedq)
+ if m:
+ e = m.end()
+ text = self.cookedq[:e]
+ self.cookedq = self.cookedq[e:]
+ break
+ if timeout is not None:
+ elapsed = time() - time_start
+ if elapsed >= timeout:
+ break
+ call_timeout = timeout-elapsed
+ poller.unregister(self)
+ if m:
+ return (i, m, text)
+ text = self.read_very_lazy()
+ if not text and self.eof:
+ raise EOFError
+ return (-1, None, text)
+
+ def _expect_with_select(self, list, timeout=None):
+ """Read until one from a list of a regular expressions matches.
+
+ The timeout is implemented using select.select().
+ """
re = None
list = list[:]
indices = range(len(list))
diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py
index 87418f5..38da08c 100644
--- a/Lib/test/test_telnetlib.py
+++ b/Lib/test/test_telnetlib.py
@@ -75,8 +75,8 @@ class GeneralTests(TestCase):
class SocketStub(object):
''' a socket proxy that re-defines sendall() '''
- def __init__(self, reads=[]):
- self.reads = reads
+ def __init__(self, reads=()):
+ self.reads = list(reads) # Intentionally make a copy.
self.writes = []
self.block = False
def sendall(self, data):
@@ -102,7 +102,7 @@ class TelnetAlike(telnetlib.Telnet):
self._messages += out.getvalue()
return
-def new_select(*s_args):
+def mock_select(*s_args):
block = False
for l in s_args:
for fob in l:
@@ -113,6 +113,30 @@ def new_select(*s_args):
else:
return s_args
+class MockPoller(object):
+ test_case = None # Set during TestCase setUp.
+
+ def __init__(self):
+ self._file_objs = []
+
+ def register(self, fd, eventmask):
+ self.test_case.assertTrue(hasattr(fd, 'fileno'), fd)
+ self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI)
+ self._file_objs.append(fd)
+
+ def poll(self, timeout=None):
+ block = False
+ for fob in self._file_objs:
+ if isinstance(fob, TelnetAlike):
+ block = fob.sock.block
+ if block:
+ return []
+ else:
+ return zip(self._file_objs, [select.POLLIN]*len(self._file_objs))
+
+ def unregister(self, fd):
+ self._file_objs.remove(fd)
+
@contextlib.contextmanager
def test_socket(reads):
def new_conn(*ignored):
@@ -125,7 +149,7 @@ def test_socket(reads):
socket.create_connection = old_conn
return
-def test_telnet(reads=[], cls=TelnetAlike):
+def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
''' return a telnetlib.Telnet object that uses a SocketStub with
reads queued up to be read '''
for x in reads:
@@ -133,15 +157,28 @@ def test_telnet(reads=[], cls=TelnetAlike):
with test_socket(reads):
telnet = cls('dummy', 0)
telnet._messages = '' # debuglevel output
+ if use_poll is not None:
+ if use_poll and not telnet._has_poll:
+ raise unittest.SkipTest('select.poll() required.')
+ telnet._has_poll = use_poll
return telnet
-class ReadTests(TestCase):
+
+class ExpectAndReadTestCase(TestCase):
def setUp(self):
self.old_select = select.select
- select.select = new_select
+ self.old_poll = select.poll
+ select.select = mock_select
+ select.poll = MockPoller
+ MockPoller.test_case = self
+
def tearDown(self):
+ MockPoller.test_case = None
+ select.poll = self.old_poll
select.select = self.old_select
+
+class ReadTests(ExpectAndReadTestCase):
def test_read_until(self):
"""
read_until(expected, timeout=None)
@@ -158,6 +195,21 @@ class ReadTests(TestCase):
data = telnet.read_until(b'match')
self.assertEqual(data, expect)
+ def test_read_until_with_poll(self):
+ """Use select.poll() to implement telnet.read_until()."""
+ want = [b'x' * 10, b'match', b'y' * 10]
+ telnet = test_telnet(want, use_poll=True)
+ select.select = lambda *_: self.fail('unexpected select() call.')
+ data = telnet.read_until(b'match')
+ self.assertEqual(data, b''.join(want[:-1]))
+
+ def test_read_until_with_select(self):
+ """Use select.select() to implement telnet.read_until()."""
+ want = [b'x' * 10, b'match', b'y' * 10]
+ telnet = test_telnet(want, use_poll=False)
+ select.poll = lambda *_: self.fail('unexpected poll() call.')
+ data = telnet.read_until(b'match')
+ self.assertEqual(data, b''.join(want[:-1]))
def test_read_all(self):
"""
@@ -349,8 +401,38 @@ class OptionTests(TestCase):
self.assertRegex(telnet._messages, r'0.*test')
+class ExpectTests(ExpectAndReadTestCase):
+ def test_expect(self):
+ """
+ expect(expected, [timeout])
+ Read until the expected string has been seen, or a timeout is
+ hit (default is no timeout); may block.
+ """
+ want = [b'x' * 10, b'match', b'y' * 10]
+ telnet = test_telnet(want)
+ (_,_,data) = telnet.expect([b'match'])
+ self.assertEqual(data, b''.join(want[:-1]))
+
+ def test_expect_with_poll(self):
+ """Use select.poll() to implement telnet.expect()."""
+ want = [b'x' * 10, b'match', b'y' * 10]
+ telnet = test_telnet(want, use_poll=True)
+ select.select = lambda *_: self.fail('unexpected select() call.')
+ (_,_,data) = telnet.expect([b'match'])
+ self.assertEqual(data, b''.join(want[:-1]))
+
+ def test_expect_with_select(self):
+ """Use select.select() to implement telnet.expect()."""
+ want = [b'x' * 10, b'match', b'y' * 10]
+ telnet = test_telnet(want, use_poll=False)
+ select.poll = lambda *_: self.fail('unexpected poll() call.')
+ (_,_,data) = telnet.expect([b'match'])
+ self.assertEqual(data, b''.join(want[:-1]))
+
+
def test_main(verbose=None):
- support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests)
+ support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,
+ ExpectTests)
if __name__ == '__main__':
test_main()
diff --git a/Misc/ACKS b/Misc/ACKS
index 26fe1d0..3bf81a2 100644
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -410,6 +410,7 @@ Chris Hoffman
Albert Hofkamp
Tomas Hoger
Jonathan Hogg
+Akintayo Holder
Gerrit Holl
Shane Holloway
Rune Holm
diff --git a/Misc/NEWS b/Misc/NEWS
index 74e4038..1d5353a 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -87,6 +87,9 @@ Core and Builtins
Library
-------
+- Issue #14635: telnetlib will use poll() rather than select() when possible
+ to avoid failing due to the select() file descriptor limit.
+
- Issue #15180: Clarify posixpath.join() error message when mixing str & bytes
- Issue #15230: runpy.run_path now correctly sets __package__ as described