summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/pty.py16
-rw-r--r--Lib/test/test_pty.py91
2 files changed, 101 insertions, 6 deletions
diff --git a/Lib/pty.py b/Lib/pty.py
index 810ebd8..3ccf619 100644
--- a/Lib/pty.py
+++ b/Lib/pty.py
@@ -142,15 +142,21 @@ def _copy(master_fd, master_read=_read, stdin_read=_read):
Copies
pty master -> standard output (master_read)
standard input -> pty master (stdin_read)"""
- while 1:
- rfds, wfds, xfds = select(
- [master_fd, STDIN_FILENO], [], [])
+ fds = [master_fd, STDIN_FILENO]
+ while True:
+ rfds, wfds, xfds = select(fds, [], [])
if master_fd in rfds:
data = master_read(master_fd)
- os.write(STDOUT_FILENO, data)
+ if not data: # Reached EOF.
+ fds.remove(master_fd)
+ else:
+ os.write(STDOUT_FILENO, data)
if STDIN_FILENO in rfds:
data = stdin_read(STDIN_FILENO)
- _writen(master_fd, data)
+ if not data:
+ fds.remove(STDIN_FILENO)
+ else:
+ _writen(master_fd, data)
def spawn(argv, master_read=_read, stdin_read=_read):
"""Create a spawned process."""
diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py
index c6fc5e7..4f1251c 100644
--- a/Lib/test/test_pty.py
+++ b/Lib/test/test_pty.py
@@ -8,7 +8,9 @@ import errno
import pty
import os
import sys
+import select
import signal
+import socket
import unittest
TEST_STRING_1 = b"I wish to buy a fish license.\n"
@@ -194,9 +196,96 @@ class PtyTest(unittest.TestCase):
# pty.fork() passed.
+
+class SmallPtyTests(unittest.TestCase):
+ """These tests don't spawn children or hang."""
+
+ def setUp(self):
+ self.orig_stdin_fileno = pty.STDIN_FILENO
+ self.orig_stdout_fileno = pty.STDOUT_FILENO
+ self.orig_pty_select = pty.select
+ self.fds = [] # A list of file descriptors to close.
+ self.select_rfds_lengths = []
+ self.select_rfds_results = []
+
+ def tearDown(self):
+ pty.STDIN_FILENO = self.orig_stdin_fileno
+ pty.STDOUT_FILENO = self.orig_stdout_fileno
+ pty.select = self.orig_pty_select
+ for fd in self.fds:
+ try:
+ os.close(fd)
+ except:
+ pass
+
+ def _pipe(self):
+ pipe_fds = os.pipe()
+ self.fds.extend(pipe_fds)
+ return pipe_fds
+
+ def _mock_select(self, rfds, wfds, xfds):
+ # This will raise IndexError when no more expected calls exist.
+ self.assertEqual(self.select_rfds_lengths.pop(0), len(rfds))
+ return self.select_rfds_results.pop(0), [], []
+
+ def test__copy_to_each(self):
+ """Test the normal data case on both master_fd and stdin."""
+ read_from_stdout_fd, mock_stdout_fd = self._pipe()
+ pty.STDOUT_FILENO = mock_stdout_fd
+ mock_stdin_fd, write_to_stdin_fd = self._pipe()
+ pty.STDIN_FILENO = mock_stdin_fd
+ socketpair = socket.socketpair()
+ masters = [s.fileno() for s in socketpair]
+ self.fds.extend(masters)
+
+ # Feed data. Smaller than PIPEBUF. These writes will not block.
+ os.write(masters[1], b'from master')
+ os.write(write_to_stdin_fd, b'from stdin')
+
+ # Expect two select calls, the last one will cause IndexError
+ pty.select = self._mock_select
+ self.select_rfds_lengths.append(2)
+ self.select_rfds_results.append([mock_stdin_fd, masters[0]])
+ self.select_rfds_lengths.append(2)
+
+ with self.assertRaises(IndexError):
+ pty._copy(masters[0])
+
+ # Test that the right data went to the right places.
+ rfds = select.select([read_from_stdout_fd, masters[1]], [], [], 0)[0]
+ self.assertEqual([read_from_stdout_fd, masters[1]], rfds)
+ self.assertEqual(os.read(read_from_stdout_fd, 20), b'from master')
+ self.assertEqual(os.read(masters[1], 20), b'from stdin')
+
+ def test__copy_eof_on_all(self):
+ """Test the empty read EOF case on both master_fd and stdin."""
+ read_from_stdout_fd, mock_stdout_fd = self._pipe()
+ pty.STDOUT_FILENO = mock_stdout_fd
+ mock_stdin_fd, write_to_stdin_fd = self._pipe()
+ pty.STDIN_FILENO = mock_stdin_fd
+ socketpair = socket.socketpair()
+ masters = [s.fileno() for s in socketpair]
+ self.fds.extend(masters)
+
+ os.close(masters[1])
+ socketpair[1].close()
+ os.close(write_to_stdin_fd)
+
+ # Expect two select calls, the last one will cause IndexError
+ pty.select = self._mock_select
+ self.select_rfds_lengths.append(2)
+ self.select_rfds_results.append([mock_stdin_fd, masters[0]])
+ # We expect that both fds were removed from the fds list as they
+ # both encountered an EOF before the second select call.
+ self.select_rfds_lengths.append(0)
+
+ with self.assertRaises(IndexError):
+ pty._copy(masters[0])
+
+
def test_main(verbose=None):
try:
- run_unittest(PtyTest)
+ run_unittest(SmallPtyTests, PtyTest)
finally:
reap_children()