summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2010-01-08 17:54:23 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2010-01-08 17:54:23 (GMT)
commitc1baa601e2b558deb690edfdf334fceee3b03327 (patch)
tree1cf896c04e483406149bb8ad9c47ce89271a3795 /Lib
parentdc2a61347b569a42f99b7f64fd59bff0d2dcb4ba (diff)
downloadcpython-c1baa601e2b558deb690edfdf334fceee3b03327.zip
cpython-c1baa601e2b558deb690edfdf334fceee3b03327.tar.gz
cpython-c1baa601e2b558deb690edfdf334fceee3b03327.tar.bz2
Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against
the destruction of weakref'ed objects while iterating.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_weakrefset.py74
-rw-r--r--Lib/test/test_weakref.py87
-rw-r--r--Lib/test/test_weakset.py50
-rw-r--r--Lib/weakref.py134
4 files changed, 291 insertions, 54 deletions
diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py
index addc7af..3de3bda 100644
--- a/Lib/_weakrefset.py
+++ b/Lib/_weakrefset.py
@@ -6,22 +6,61 @@ from _weakref import ref
__all__ = ['WeakSet']
+
+class _IterationGuard:
+ # This context manager registers itself in the current iterators of the
+ # weak container, such as to delay all removals until the context manager
+ # exits.
+ # This technique should be relatively thread-safe (since sets are).
+
+ def __init__(self, weakcontainer):
+ # Don't create cycles
+ self.weakcontainer = ref(weakcontainer)
+
+ def __enter__(self):
+ w = self.weakcontainer()
+ if w is not None:
+ w._iterating.add(self)
+ return self
+
+ def __exit__(self, e, t, b):
+ w = self.weakcontainer()
+ if w is not None:
+ s = w._iterating
+ s.remove(self)
+ if not s:
+ w._commit_removals()
+
+
class WeakSet:
def __init__(self, data=None):
self.data = set()
def _remove(item, selfref=ref(self)):
self = selfref()
if self is not None:
- self.data.discard(item)
+ if self._iterating:
+ self._pending_removals.append(item)
+ else:
+ self.data.discard(item)
self._remove = _remove
+ # A list of keys to be removed
+ self._pending_removals = []
+ self._iterating = set()
if data is not None:
self.update(data)
+ def _commit_removals(self):
+ l = self._pending_removals
+ discard = self.data.discard
+ while l:
+ discard(l.pop())
+
def __iter__(self):
- for itemref in self.data:
- item = itemref()
- if item is not None:
- yield item
+ with _IterationGuard(self):
+ for itemref in self.data:
+ item = itemref()
+ if item is not None:
+ yield item
def __len__(self):
return sum(x() is not None for x in self.data)
@@ -34,15 +73,21 @@ class WeakSet:
getattr(self, '__dict__', None))
def add(self, item):
+ if self._pending_removals:
+ self._commit_removals()
self.data.add(ref(item, self._remove))
def clear(self):
+ if self._pending_removals:
+ self._commit_removals()
self.data.clear()
def copy(self):
return self.__class__(self)
def pop(self):
+ if self._pending_removals:
+ self._commit_removals()
while True:
try:
itemref = self.data.pop()
@@ -53,17 +98,24 @@ class WeakSet:
return item
def remove(self, item):
+ if self._pending_removals:
+ self._commit_removals()
self.data.remove(ref(item))
def discard(self, item):
+ if self._pending_removals:
+ self._commit_removals()
self.data.discard(ref(item))
def update(self, other):
+ if self._pending_removals:
+ self._commit_removals()
if isinstance(other, self.__class__):
self.data.update(other.data)
else:
for element in other:
self.add(element)
+
def __ior__(self, other):
self.update(other)
return self
@@ -82,11 +134,15 @@ class WeakSet:
__sub__ = difference
def difference_update(self, other):
+ if self._pending_removals:
+ self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
def __isub__(self, other):
+ if self._pending_removals:
+ self._commit_removals()
if self is other:
self.data.clear()
else:
@@ -98,8 +154,12 @@ class WeakSet:
__and__ = intersection
def intersection_update(self, other):
+ if self._pending_removals:
+ self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other):
+ if self._pending_removals:
+ self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
return self
@@ -127,11 +187,15 @@ class WeakSet:
__xor__ = symmetric_difference
def symmetric_difference_update(self, other):
+ if self._pending_removals:
+ self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other):
+ if self._pending_removals:
+ self._commit_removals()
if self is other:
self.data.clear()
else:
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index ecf1976..028b418 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -4,6 +4,8 @@ import unittest
import collections
import weakref
import operator
+import contextlib
+import copy
from test import support
@@ -788,6 +790,10 @@ class Object:
self.arg = arg
def __repr__(self):
return "<Object %r>" % self.arg
+ def __eq__(self, other):
+ if isinstance(other, Object):
+ return self.arg == other.arg
+ return NotImplemented
def __lt__(self, other):
if isinstance(other, Object):
return self.arg < other.arg
@@ -935,6 +941,87 @@ class MappingTestCase(TestBase):
self.assertFalse(values,
"itervalues() did not touch all values")
+ def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
+ n = len(dict)
+ it = iter(getattr(dict, iter_name)())
+ next(it) # Trigger internal iteration
+ # Destroy an object
+ del objects[-1]
+ gc.collect() # just in case
+ # We have removed either the first consumed object, or another one
+ self.assertIn(len(list(it)), [len(objects), len(objects) - 1])
+ del it
+ # The removal has been committed
+ self.assertEqual(len(dict), n - 1)
+
+ def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
+ # Check that we can explicitly mutate the weak dict without
+ # interfering with delayed removal.
+ # `testcontext` should create an iterator, destroy one of the
+ # weakref'ed objects and then return a new key/value pair corresponding
+ # to the destroyed object.
+ with testcontext() as (k, v):
+ self.assertFalse(k in dict)
+ with testcontext() as (k, v):
+ self.assertRaises(KeyError, dict.__delitem__, k)
+ self.assertFalse(k in dict)
+ with testcontext() as (k, v):
+ self.assertRaises(KeyError, dict.pop, k)
+ self.assertFalse(k in dict)
+ with testcontext() as (k, v):
+ dict[k] = v
+ self.assertEqual(dict[k], v)
+ ddict = copy.copy(dict)
+ with testcontext() as (k, v):
+ dict.update(ddict)
+ self.assertEqual(dict, ddict)
+ with testcontext() as (k, v):
+ dict.clear()
+ self.assertEqual(len(dict), 0)
+
+ def test_weak_keys_destroy_while_iterating(self):
+ # Issue #7105: iterators shouldn't crash when a key is implicitly removed
+ dict, objects = self.make_weak_keyed_dict()
+ self.check_weak_destroy_while_iterating(dict, objects, 'keys')
+ self.check_weak_destroy_while_iterating(dict, objects, 'items')
+ self.check_weak_destroy_while_iterating(dict, objects, 'values')
+ self.check_weak_destroy_while_iterating(dict, objects, 'keyrefs')
+ dict, objects = self.make_weak_keyed_dict()
+ @contextlib.contextmanager
+ def testcontext():
+ try:
+ it = iter(dict.items())
+ next(it)
+ # Schedule a key/value for removal and recreate it
+ v = objects.pop().arg
+ gc.collect() # just in case
+ yield Object(v), v
+ finally:
+ it = None # should commit all removals
+ self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
+
+ def test_weak_values_destroy_while_iterating(self):
+ # Issue #7105: iterators shouldn't crash when a key is implicitly removed
+ dict, objects = self.make_weak_valued_dict()
+ self.check_weak_destroy_while_iterating(dict, objects, 'keys')
+ self.check_weak_destroy_while_iterating(dict, objects, 'items')
+ self.check_weak_destroy_while_iterating(dict, objects, 'values')
+ self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
+ self.check_weak_destroy_while_iterating(dict, objects, 'valuerefs')
+ dict, objects = self.make_weak_valued_dict()
+ @contextlib.contextmanager
+ def testcontext():
+ try:
+ it = iter(dict.items())
+ next(it)
+ # Schedule a key/value for removal and recreate it
+ k = objects.pop().arg
+ gc.collect() # just in case
+ yield k, Object(k)
+ finally:
+ it = None # should commit all removals
+ self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
+
def test_make_weak_keyed_dict_from_dict(self):
o = Object(3)
dict = weakref.WeakKeyDictionary({o:364})
diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py
index 651efe2..4e0aa38 100644
--- a/Lib/test/test_weakset.py
+++ b/Lib/test/test_weakset.py
@@ -10,6 +10,8 @@ import sys
import warnings
import collections
from collections import UserString as ustr
+import gc
+import contextlib
class Foo:
@@ -307,6 +309,54 @@ class TestWeakSet(unittest.TestCase):
self.assertFalse(self.s == WeakSet([Foo]))
self.assertFalse(self.s == 1)
+ def test_weak_destroy_while_iterating(self):
+ # Issue #7105: iterators shouldn't crash when a key is implicitly removed
+ # Create new items to be sure no-one else holds a reference
+ items = [ustr(c) for c in ('a', 'b', 'c')]
+ s = WeakSet(items)
+ it = iter(s)
+ next(it) # Trigger internal iteration
+ # Destroy an item
+ del items[-1]
+ gc.collect() # just in case
+ # We have removed either the first consumed items, or another one
+ self.assertIn(len(list(it)), [len(items), len(items) - 1])
+ del it
+ # The removal has been committed
+ self.assertEqual(len(s), len(items))
+
+ def test_weak_destroy_and_mutate_while_iterating(self):
+ # Issue #7105: iterators shouldn't crash when a key is implicitly removed
+ items = [ustr(c) for c in string.ascii_letters]
+ s = WeakSet(items)
+ @contextlib.contextmanager
+ def testcontext():
+ try:
+ it = iter(s)
+ next(it)
+ # Schedule an item for removal and recreate it
+ u = ustr(str(items.pop()))
+ gc.collect() # just in case
+ yield u
+ finally:
+ it = None # should commit all removals
+
+ with testcontext() as u:
+ self.assertFalse(u in s)
+ with testcontext() as u:
+ self.assertRaises(KeyError, s.remove, u)
+ self.assertFalse(u in s)
+ with testcontext() as u:
+ s.add(u)
+ self.assertTrue(u in s)
+ t = s.copy()
+ with testcontext() as u:
+ s.update(t)
+ self.assertEqual(len(s), len(t))
+ with testcontext() as u:
+ s.clear()
+ self.assertEqual(len(s), 0)
+
def test_main(verbose=None):
support.run_unittest(TestWeakSet)
diff --git a/Lib/weakref.py b/Lib/weakref.py
index 5e6cc8b..66c4dc6 100644
--- a/Lib/weakref.py
+++ b/Lib/weakref.py
@@ -18,7 +18,7 @@ from _weakref import (
ProxyType,
ReferenceType)
-from _weakrefset import WeakSet
+from _weakrefset import WeakSet, _IterationGuard
import collections # Import after _weakref to avoid circular import.
@@ -46,11 +46,25 @@ class WeakValueDictionary(collections.MutableMapping):
def remove(wr, selfref=ref(self)):
self = selfref()
if self is not None:
- del self.data[wr.key]
+ if self._iterating:
+ self._pending_removals.append(wr.key)
+ else:
+ del self.data[wr.key]
self._remove = remove
+ # A list of keys to be removed
+ self._pending_removals = []
+ self._iterating = set()
self.data = d = {}
self.update(*args, **kw)
+ def _commit_removals(self):
+ l = self._pending_removals
+ d = self.data
+ # We shouldn't encounter any KeyError, because this method should
+ # always be called *before* mutating the dict.
+ while l:
+ del d[l.pop()]
+
def __getitem__(self, key):
o = self.data[key]()
if o is None:
@@ -59,6 +73,8 @@ class WeakValueDictionary(collections.MutableMapping):
return o
def __delitem__(self, key):
+ if self._pending_removals:
+ self._commit_removals()
del self.data[key]
def __len__(self):
@@ -75,6 +91,8 @@ class WeakValueDictionary(collections.MutableMapping):
return "<WeakValueDictionary at %s>" % id(self)
def __setitem__(self, key, value):
+ if self._pending_removals:
+ self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key)
def copy(self):
@@ -110,24 +128,19 @@ class WeakValueDictionary(collections.MutableMapping):
return o
def items(self):
- L = []
- for key, wr in self.data.items():
- o = wr()
- if o is not None:
- L.append((key, o))
- return L
-
- def items(self):
- for wr in self.data.values():
- value = wr()
- if value is not None:
- yield wr.key, value
+ with _IterationGuard(self):
+ for k, wr in self.data.items():
+ v = wr()
+ if v is not None:
+ yield k, v
def keys(self):
- return iter(self.data.keys())
+ with _IterationGuard(self):
+ for k, wr in self.data.items():
+ if wr() is not None:
+ yield k
- def __iter__(self):
- return iter(self.data.keys())
+ __iter__ = keys
def itervaluerefs(self):
"""Return an iterator that yields the weak references to the values.
@@ -139,15 +152,20 @@ class WeakValueDictionary(collections.MutableMapping):
keep the values around longer than needed.
"""
- return self.data.values()
+ with _IterationGuard(self):
+ for wr in self.data.values():
+ yield wr
def values(self):
- for wr in self.data.values():
- obj = wr()
- if obj is not None:
- yield obj
+ with _IterationGuard(self):
+ for wr in self.data.values():
+ obj = wr()
+ if obj is not None:
+ yield obj
def popitem(self):
+ if self._pending_removals:
+ self._commit_removals()
while 1:
key, wr = self.data.popitem()
o = wr()
@@ -155,6 +173,8 @@ class WeakValueDictionary(collections.MutableMapping):
return key, o
def pop(self, key, *args):
+ if self._pending_removals:
+ self._commit_removals()
try:
o = self.data.pop(key)()
except KeyError:
@@ -170,12 +190,16 @@ class WeakValueDictionary(collections.MutableMapping):
try:
wr = self.data[key]
except KeyError:
+ if self._pending_removals:
+ self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key)
return default
else:
return wr()
def update(self, dict=None, **kwargs):
+ if self._pending_removals:
+ self._commit_removals()
d = self.data
if dict is not None:
if not hasattr(dict, "items"):
@@ -195,7 +219,7 @@ class WeakValueDictionary(collections.MutableMapping):
keep the values around longer than needed.
"""
- return self.data.values()
+ return list(self.data.values())
class KeyedRef(ref):
@@ -235,9 +259,29 @@ class WeakKeyDictionary(collections.MutableMapping):
def remove(k, selfref=ref(self)):
self = selfref()
if self is not None:
- del self.data[k]
+ if self._iterating:
+ self._pending_removals.append(k)
+ else:
+ del self.data[k]
self._remove = remove
- if dict is not None: self.update(dict)
+ # A list of dead weakrefs (keys to be removed)
+ self._pending_removals = []
+ self._iterating = set()
+ if dict is not None:
+ self.update(dict)
+
+ def _commit_removals(self):
+ # NOTE: We don't need to call this method before mutating the dict,
+ # because a dead weakref never compares equal to a live weakref,
+ # even if they happened to refer to equal objects.
+ # However, it means keys may already have been removed.
+ l = self._pending_removals
+ d = self.data
+ while l:
+ try:
+ del d[l.pop()]
+ except KeyError:
+ pass
def __delitem__(self, key):
del self.data[ref(key)]
@@ -284,34 +328,26 @@ class WeakKeyDictionary(collections.MutableMapping):
return wr in self.data
def items(self):
- for wr, value in self.data.items():
- key = wr()
- if key is not None:
- yield key, value
-
- def keyrefs(self):
- """Return an iterator that yields the weak references to the keys.
-
- The references are not guaranteed to be 'live' at the time
- they are used, so the result of calling the references needs
- to be checked before being used. This can be used to avoid
- creating references that will cause the garbage collector to
- keep the keys around longer than needed.
-
- """
- return self.data.keys()
+ with _IterationGuard(self):
+ for wr, value in self.data.items():
+ key = wr()
+ if key is not None:
+ yield key, value
def keys(self):
- for wr in self.data.keys():
- obj = wr()
- if obj is not None:
- yield obj
+ with _IterationGuard(self):
+ for wr in self.data:
+ obj = wr()
+ if obj is not None:
+ yield obj
- def __iter__(self):
- return iter(self.keys())
+ __iter__ = keys
def values(self):
- return iter(self.data.values())
+ with _IterationGuard(self):
+ for wr, value in self.data.items():
+ if wr() is not None:
+ yield value
def keyrefs(self):
"""Return a list of weak references to the keys.
@@ -323,7 +359,7 @@ class WeakKeyDictionary(collections.MutableMapping):
keep the keys around longer than needed.
"""
- return self.data.keys()
+ return list(self.data)
def popitem(self):
while 1: