From c1baa601e2b558deb690edfdf334fceee3b03327 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Fri, 8 Jan 2010 17:54:23 +0000 Subject: Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against the destruction of weakref'ed objects while iterating. --- Doc/library/weakref.rst | 4 +- Lib/_weakrefset.py | 74 ++++++++++++++++++++++++-- Lib/test/test_weakref.py | 87 ++++++++++++++++++++++++++++++ Lib/test/test_weakset.py | 50 ++++++++++++++++++ Lib/weakref.py | 134 ++++++++++++++++++++++++++++++----------------- Misc/NEWS | 3 ++ 6 files changed, 296 insertions(+), 56 deletions(-) diff --git a/Doc/library/weakref.rst b/Doc/library/weakref.rst index 2aa49e3..53b13e5 100644 --- a/Doc/library/weakref.rst +++ b/Doc/library/weakref.rst @@ -159,7 +159,7 @@ than needed. .. method:: WeakKeyDictionary.keyrefs() - Return an :term:`iterator` that yields the weak references to the keys. + Return an iterable of the weak references to the keys. .. class:: WeakValueDictionary([dict]) @@ -182,7 +182,7 @@ These method have the same issues as the and :meth:`keyrefs` method of .. method:: WeakValueDictionary.valuerefs() - Return an :term:`iterator` that yields the weak references to the values. + Return an iterable of the weak references to the values. .. class:: WeakSet([elements]) diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py index addc7af..3de3bda 100644 --- a/Lib/_weakrefset.py +++ b/Lib/_weakrefset.py @@ -6,22 +6,61 @@ from _weakref import ref __all__ = ['WeakSet'] + +class _IterationGuard: + # This context manager registers itself in the current iterators of the + # weak container, such as to delay all removals until the context manager + # exits. + # This technique should be relatively thread-safe (since sets are). + + def __init__(self, weakcontainer): + # Don't create cycles + self.weakcontainer = ref(weakcontainer) + + def __enter__(self): + w = self.weakcontainer() + if w is not None: + w._iterating.add(self) + return self + + def __exit__(self, e, t, b): + w = self.weakcontainer() + if w is not None: + s = w._iterating + s.remove(self) + if not s: + w._commit_removals() + + class WeakSet: def __init__(self, data=None): self.data = set() def _remove(item, selfref=ref(self)): self = selfref() if self is not None: - self.data.discard(item) + if self._iterating: + self._pending_removals.append(item) + else: + self.data.discard(item) self._remove = _remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() if data is not None: self.update(data) + def _commit_removals(self): + l = self._pending_removals + discard = self.data.discard + while l: + discard(l.pop()) + def __iter__(self): - for itemref in self.data: - item = itemref() - if item is not None: - yield item + with _IterationGuard(self): + for itemref in self.data: + item = itemref() + if item is not None: + yield item def __len__(self): return sum(x() is not None for x in self.data) @@ -34,15 +73,21 @@ class WeakSet: getattr(self, '__dict__', None)) def add(self, item): + if self._pending_removals: + self._commit_removals() self.data.add(ref(item, self._remove)) def clear(self): + if self._pending_removals: + self._commit_removals() self.data.clear() def copy(self): return self.__class__(self) def pop(self): + if self._pending_removals: + self._commit_removals() while True: try: itemref = self.data.pop() @@ -53,17 +98,24 @@ class WeakSet: return item def remove(self, item): + if self._pending_removals: + self._commit_removals() self.data.remove(ref(item)) def discard(self, item): + if self._pending_removals: + self._commit_removals() self.data.discard(ref(item)) def update(self, other): + if self._pending_removals: + self._commit_removals() if isinstance(other, self.__class__): self.data.update(other.data) else: for element in other: self.add(element) + def __ior__(self, other): self.update(other) return self @@ -82,11 +134,15 @@ class WeakSet: __sub__ = difference def difference_update(self, other): + if self._pending_removals: + self._commit_removals() if self is other: self.data.clear() else: self.data.difference_update(ref(item) for item in other) def __isub__(self, other): + if self._pending_removals: + self._commit_removals() if self is other: self.data.clear() else: @@ -98,8 +154,12 @@ class WeakSet: __and__ = intersection def intersection_update(self, other): + if self._pending_removals: + self._commit_removals() self.data.intersection_update(ref(item) for item in other) def __iand__(self, other): + if self._pending_removals: + self._commit_removals() self.data.intersection_update(ref(item) for item in other) return self @@ -127,11 +187,15 @@ class WeakSet: __xor__ = symmetric_difference def symmetric_difference_update(self, other): + if self._pending_removals: + self._commit_removals() if self is other: self.data.clear() else: self.data.symmetric_difference_update(ref(item) for item in other) def __ixor__(self, other): + if self._pending_removals: + self._commit_removals() if self is other: self.data.clear() else: diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index ecf1976..028b418 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -4,6 +4,8 @@ import unittest import collections import weakref import operator +import contextlib +import copy from test import support @@ -788,6 +790,10 @@ class Object: self.arg = arg def __repr__(self): return "" % self.arg + def __eq__(self, other): + if isinstance(other, Object): + return self.arg == other.arg + return NotImplemented def __lt__(self, other): if isinstance(other, Object): return self.arg < other.arg @@ -935,6 +941,87 @@ class MappingTestCase(TestBase): self.assertFalse(values, "itervalues() did not touch all values") + def check_weak_destroy_while_iterating(self, dict, objects, iter_name): + n = len(dict) + it = iter(getattr(dict, iter_name)()) + next(it) # Trigger internal iteration + # Destroy an object + del objects[-1] + gc.collect() # just in case + # We have removed either the first consumed object, or another one + self.assertIn(len(list(it)), [len(objects), len(objects) - 1]) + del it + # The removal has been committed + self.assertEqual(len(dict), n - 1) + + def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext): + # Check that we can explicitly mutate the weak dict without + # interfering with delayed removal. + # `testcontext` should create an iterator, destroy one of the + # weakref'ed objects and then return a new key/value pair corresponding + # to the destroyed object. + with testcontext() as (k, v): + self.assertFalse(k in dict) + with testcontext() as (k, v): + self.assertRaises(KeyError, dict.__delitem__, k) + self.assertFalse(k in dict) + with testcontext() as (k, v): + self.assertRaises(KeyError, dict.pop, k) + self.assertFalse(k in dict) + with testcontext() as (k, v): + dict[k] = v + self.assertEqual(dict[k], v) + ddict = copy.copy(dict) + with testcontext() as (k, v): + dict.update(ddict) + self.assertEqual(dict, ddict) + with testcontext() as (k, v): + dict.clear() + self.assertEqual(len(dict), 0) + + def test_weak_keys_destroy_while_iterating(self): + # Issue #7105: iterators shouldn't crash when a key is implicitly removed + dict, objects = self.make_weak_keyed_dict() + self.check_weak_destroy_while_iterating(dict, objects, 'keys') + self.check_weak_destroy_while_iterating(dict, objects, 'items') + self.check_weak_destroy_while_iterating(dict, objects, 'values') + self.check_weak_destroy_while_iterating(dict, objects, 'keyrefs') + dict, objects = self.make_weak_keyed_dict() + @contextlib.contextmanager + def testcontext(): + try: + it = iter(dict.items()) + next(it) + # Schedule a key/value for removal and recreate it + v = objects.pop().arg + gc.collect() # just in case + yield Object(v), v + finally: + it = None # should commit all removals + self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext) + + def test_weak_values_destroy_while_iterating(self): + # Issue #7105: iterators shouldn't crash when a key is implicitly removed + dict, objects = self.make_weak_valued_dict() + self.check_weak_destroy_while_iterating(dict, objects, 'keys') + self.check_weak_destroy_while_iterating(dict, objects, 'items') + self.check_weak_destroy_while_iterating(dict, objects, 'values') + self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs') + self.check_weak_destroy_while_iterating(dict, objects, 'valuerefs') + dict, objects = self.make_weak_valued_dict() + @contextlib.contextmanager + def testcontext(): + try: + it = iter(dict.items()) + next(it) + # Schedule a key/value for removal and recreate it + k = objects.pop().arg + gc.collect() # just in case + yield k, Object(k) + finally: + it = None # should commit all removals + self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext) + def test_make_weak_keyed_dict_from_dict(self): o = Object(3) dict = weakref.WeakKeyDictionary({o:364}) diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py index 651efe2..4e0aa38 100644 --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -10,6 +10,8 @@ import sys import warnings import collections from collections import UserString as ustr +import gc +import contextlib class Foo: @@ -307,6 +309,54 @@ class TestWeakSet(unittest.TestCase): self.assertFalse(self.s == WeakSet([Foo])) self.assertFalse(self.s == 1) + def test_weak_destroy_while_iterating(self): + # Issue #7105: iterators shouldn't crash when a key is implicitly removed + # Create new items to be sure no-one else holds a reference + items = [ustr(c) for c in ('a', 'b', 'c')] + s = WeakSet(items) + it = iter(s) + next(it) # Trigger internal iteration + # Destroy an item + del items[-1] + gc.collect() # just in case + # We have removed either the first consumed items, or another one + self.assertIn(len(list(it)), [len(items), len(items) - 1]) + del it + # The removal has been committed + self.assertEqual(len(s), len(items)) + + def test_weak_destroy_and_mutate_while_iterating(self): + # Issue #7105: iterators shouldn't crash when a key is implicitly removed + items = [ustr(c) for c in string.ascii_letters] + s = WeakSet(items) + @contextlib.contextmanager + def testcontext(): + try: + it = iter(s) + next(it) + # Schedule an item for removal and recreate it + u = ustr(str(items.pop())) + gc.collect() # just in case + yield u + finally: + it = None # should commit all removals + + with testcontext() as u: + self.assertFalse(u in s) + with testcontext() as u: + self.assertRaises(KeyError, s.remove, u) + self.assertFalse(u in s) + with testcontext() as u: + s.add(u) + self.assertTrue(u in s) + t = s.copy() + with testcontext() as u: + s.update(t) + self.assertEqual(len(s), len(t)) + with testcontext() as u: + s.clear() + self.assertEqual(len(s), 0) + def test_main(verbose=None): support.run_unittest(TestWeakSet) diff --git a/Lib/weakref.py b/Lib/weakref.py index 5e6cc8b..66c4dc6 100644 --- a/Lib/weakref.py +++ b/Lib/weakref.py @@ -18,7 +18,7 @@ from _weakref import ( ProxyType, ReferenceType) -from _weakrefset import WeakSet +from _weakrefset import WeakSet, _IterationGuard import collections # Import after _weakref to avoid circular import. @@ -46,11 +46,25 @@ class WeakValueDictionary(collections.MutableMapping): def remove(wr, selfref=ref(self)): self = selfref() if self is not None: - del self.data[wr.key] + if self._iterating: + self._pending_removals.append(wr.key) + else: + del self.data[wr.key] self._remove = remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() self.data = d = {} self.update(*args, **kw) + def _commit_removals(self): + l = self._pending_removals + d = self.data + # We shouldn't encounter any KeyError, because this method should + # always be called *before* mutating the dict. + while l: + del d[l.pop()] + def __getitem__(self, key): o = self.data[key]() if o is None: @@ -59,6 +73,8 @@ class WeakValueDictionary(collections.MutableMapping): return o def __delitem__(self, key): + if self._pending_removals: + self._commit_removals() del self.data[key] def __len__(self): @@ -75,6 +91,8 @@ class WeakValueDictionary(collections.MutableMapping): return "" % id(self) def __setitem__(self, key, value): + if self._pending_removals: + self._commit_removals() self.data[key] = KeyedRef(value, self._remove, key) def copy(self): @@ -110,24 +128,19 @@ class WeakValueDictionary(collections.MutableMapping): return o def items(self): - L = [] - for key, wr in self.data.items(): - o = wr() - if o is not None: - L.append((key, o)) - return L - - def items(self): - for wr in self.data.values(): - value = wr() - if value is not None: - yield wr.key, value + with _IterationGuard(self): + for k, wr in self.data.items(): + v = wr() + if v is not None: + yield k, v def keys(self): - return iter(self.data.keys()) + with _IterationGuard(self): + for k, wr in self.data.items(): + if wr() is not None: + yield k - def __iter__(self): - return iter(self.data.keys()) + __iter__ = keys def itervaluerefs(self): """Return an iterator that yields the weak references to the values. @@ -139,15 +152,20 @@ class WeakValueDictionary(collections.MutableMapping): keep the values around longer than needed. """ - return self.data.values() + with _IterationGuard(self): + for wr in self.data.values(): + yield wr def values(self): - for wr in self.data.values(): - obj = wr() - if obj is not None: - yield obj + with _IterationGuard(self): + for wr in self.data.values(): + obj = wr() + if obj is not None: + yield obj def popitem(self): + if self._pending_removals: + self._commit_removals() while 1: key, wr = self.data.popitem() o = wr() @@ -155,6 +173,8 @@ class WeakValueDictionary(collections.MutableMapping): return key, o def pop(self, key, *args): + if self._pending_removals: + self._commit_removals() try: o = self.data.pop(key)() except KeyError: @@ -170,12 +190,16 @@ class WeakValueDictionary(collections.MutableMapping): try: wr = self.data[key] except KeyError: + if self._pending_removals: + self._commit_removals() self.data[key] = KeyedRef(default, self._remove, key) return default else: return wr() def update(self, dict=None, **kwargs): + if self._pending_removals: + self._commit_removals() d = self.data if dict is not None: if not hasattr(dict, "items"): @@ -195,7 +219,7 @@ class WeakValueDictionary(collections.MutableMapping): keep the values around longer than needed. """ - return self.data.values() + return list(self.data.values()) class KeyedRef(ref): @@ -235,9 +259,29 @@ class WeakKeyDictionary(collections.MutableMapping): def remove(k, selfref=ref(self)): self = selfref() if self is not None: - del self.data[k] + if self._iterating: + self._pending_removals.append(k) + else: + del self.data[k] self._remove = remove - if dict is not None: self.update(dict) + # A list of dead weakrefs (keys to be removed) + self._pending_removals = [] + self._iterating = set() + if dict is not None: + self.update(dict) + + def _commit_removals(self): + # NOTE: We don't need to call this method before mutating the dict, + # because a dead weakref never compares equal to a live weakref, + # even if they happened to refer to equal objects. + # However, it means keys may already have been removed. + l = self._pending_removals + d = self.data + while l: + try: + del d[l.pop()] + except KeyError: + pass def __delitem__(self, key): del self.data[ref(key)] @@ -284,34 +328,26 @@ class WeakKeyDictionary(collections.MutableMapping): return wr in self.data def items(self): - for wr, value in self.data.items(): - key = wr() - if key is not None: - yield key, value - - def keyrefs(self): - """Return an iterator that yields the weak references to the keys. - - The references are not guaranteed to be 'live' at the time - they are used, so the result of calling the references needs - to be checked before being used. This can be used to avoid - creating references that will cause the garbage collector to - keep the keys around longer than needed. - - """ - return self.data.keys() + with _IterationGuard(self): + for wr, value in self.data.items(): + key = wr() + if key is not None: + yield key, value def keys(self): - for wr in self.data.keys(): - obj = wr() - if obj is not None: - yield obj + with _IterationGuard(self): + for wr in self.data: + obj = wr() + if obj is not None: + yield obj - def __iter__(self): - return iter(self.keys()) + __iter__ = keys def values(self): - return iter(self.data.values()) + with _IterationGuard(self): + for wr, value in self.data.items(): + if wr() is not None: + yield value def keyrefs(self): """Return a list of weak references to the keys. @@ -323,7 +359,7 @@ class WeakKeyDictionary(collections.MutableMapping): keep the keys around longer than needed. """ - return self.data.keys() + return list(self.data) def popitem(self): while 1: diff --git a/Misc/NEWS b/Misc/NEWS index 23fec21..4cfafff 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -194,6 +194,9 @@ C-API Library ------- +- Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against + the destruction of weakref'ed objects while iterating. + - Issue #7455: Fix possible crash in cPickle on invalid input. Patch by Victor Stinner. -- cgit v0.12