summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/selectors.py30
-rw-r--r--Lib/test/test_selectors.py27
2 files changed, 56 insertions, 1 deletions
diff --git a/Lib/selectors.py b/Lib/selectors.py
index f29d11f..edde22c 100644
--- a/Lib/selectors.py
+++ b/Lib/selectors.py
@@ -252,7 +252,6 @@ class _BaseSelectorImpl(BaseSelector):
return key
def modify(self, fileobj, events, data=None):
- # TODO: Subclasses can probably optimize this even further.
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
@@ -342,6 +341,8 @@ class SelectSelector(_BaseSelectorImpl):
class _PollLikeSelector(_BaseSelectorImpl):
"""Base class shared between poll, epoll and devpoll selectors."""
_selector_cls = None
+ _EVENT_READ = None
+ _EVENT_WRITE = None
def __init__(self):
super().__init__()
@@ -371,6 +372,33 @@ class _PollLikeSelector(_BaseSelectorImpl):
pass
return key
+ def modify(self, fileobj, events, data=None):
+ try:
+ key = self._fd_to_key[self._fileobj_lookup(fileobj)]
+ except KeyError:
+ raise KeyError(f"{fileobj!r} is not registered") from None
+
+ changed = False
+ if events != key.events:
+ selector_events = 0
+ if events & EVENT_READ:
+ selector_events |= self._EVENT_READ
+ if events & EVENT_WRITE:
+ selector_events |= self._EVENT_WRITE
+ try:
+ self._selector.modify(key.fd, selector_events)
+ except Exception:
+ super().unregister(fileobj)
+ raise
+ changed = True
+ if data != key.data:
+ changed = True
+
+ if changed:
+ key = key._replace(events=events, data=data)
+ self._fd_to_key[key.fd] = key
+ return key
+
def select(self, timeout=None):
# This is shared between poll() and epoll().
# epoll() has a different signature and handling of timeout parameter.
diff --git a/Lib/test/test_selectors.py b/Lib/test/test_selectors.py
index 852b2fe..f2594a6 100644
--- a/Lib/test/test_selectors.py
+++ b/Lib/test/test_selectors.py
@@ -175,6 +175,33 @@ class BaseSelectorTestCase(unittest.TestCase):
self.assertFalse(s.register.called)
self.assertFalse(s.unregister.called)
+ def test_modify_unregister(self):
+ # Make sure the fd is unregister()ed in case of error on
+ # modify(): http://bugs.python.org/issue30014
+ if self.SELECTOR.__name__ == 'EpollSelector':
+ patch = unittest.mock.patch(
+ 'selectors.EpollSelector._selector_cls')
+ elif self.SELECTOR.__name__ == 'PollSelector':
+ patch = unittest.mock.patch(
+ 'selectors.PollSelector._selector_cls')
+ elif self.SELECTOR.__name__ == 'DevpollSelector':
+ patch = unittest.mock.patch(
+ 'selectors.DevpollSelector._selector_cls')
+ else:
+ raise self.skipTest("")
+
+ with patch as m:
+ m.return_value.modify = unittest.mock.Mock(
+ side_effect=ZeroDivisionError)
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+ rd, wr = self.make_socketpair()
+ s.register(rd, selectors.EVENT_READ)
+ self.assertEqual(len(s._map), 1)
+ with self.assertRaises(ZeroDivisionError):
+ s.modify(rd, selectors.EVENT_WRITE)
+ self.assertEqual(len(s._map), 0)
+
def test_close(self):
s = self.SELECTOR()
self.addCleanup(s.close)