summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKumar Aditya <kumaraditya@python.org>2024-10-13 15:35:05 (GMT)
committerGitHub <noreply@github.com>2024-10-13 15:35:05 (GMT)
commitcd0f9d111a040ad863c680e9f464419640c8c3fd (patch)
tree8bb24147b6b0e69f406286ee456817fb50e9026a
parent08489325d1cd94eba97c5f5f8cac49521fd0b0d7 (diff)
downloadcpython-cd0f9d111a040ad863c680e9f464419640c8c3fd.zip
cpython-cd0f9d111a040ad863c680e9f464419640c8c3fd.tar.gz
cpython-cd0f9d111a040ad863c680e9f464419640c8c3fd.tar.bz2
gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#125325)
Make `WeakKeyDictionary` and `WeakValueDictionary` thread safe by copying the underlying the dict before iterating over it.
-rw-r--r--Lib/_weakrefset.py25
-rw-r--r--Lib/weakref.py198
-rw-r--r--Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst1
3 files changed, 50 insertions, 174 deletions
diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py
index 2071755..d1c7fca 100644
--- a/Lib/_weakrefset.py
+++ b/Lib/_weakrefset.py
@@ -8,31 +8,6 @@ from types import GenericAlias
__all__ = ['WeakSet']
-class _IterationGuard:
- # This context manager registers itself in the current iterators of the
- # weak container, such as to delay all removals until the context manager
- # exits.
- # This technique should be relatively thread-safe (since sets are).
-
- def __init__(self, weakcontainer):
- # Don't create cycles
- self.weakcontainer = ref(weakcontainer)
-
- def __enter__(self):
- w = self.weakcontainer()
- if w is not None:
- w._iterating.add(self)
- return self
-
- def __exit__(self, e, t, b):
- w = self.weakcontainer()
- if w is not None:
- s = w._iterating
- s.remove(self)
- if not s:
- w._commit_removals()
-
-
class WeakSet:
def __init__(self, data=None):
self.data = set()
diff --git a/Lib/weakref.py b/Lib/weakref.py
index 25b7092..94e4278 100644
--- a/Lib/weakref.py
+++ b/Lib/weakref.py
@@ -19,7 +19,7 @@ from _weakref import (
ReferenceType,
_remove_dead_weakref)
-from _weakrefset import WeakSet, _IterationGuard
+from _weakrefset import WeakSet
import _collections_abc # Import after _weakref to avoid circular import.
import sys
@@ -105,34 +105,14 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
self = selfref()
if self is not None:
- if self._iterating:
- self._pending_removals.append(wr.key)
- else:
- # Atomic removal is necessary since this function
- # can be called asynchronously by the GC
- _atomic_removal(self.data, wr.key)
+ # Atomic removal is necessary since this function
+ # can be called asynchronously by the GC
+ _atomic_removal(self.data, wr.key)
self._remove = remove
- # A list of keys to be removed
- self._pending_removals = []
- self._iterating = set()
self.data = {}
self.update(other, **kw)
- def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
- pop = self._pending_removals.pop
- d = self.data
- # We shouldn't encounter any KeyError, because this method should
- # always be called *before* mutating the dict.
- while True:
- try:
- key = pop()
- except IndexError:
- return
- _atomic_removal(d, key)
-
def __getitem__(self, key):
- if self._pending_removals:
- self._commit_removals()
o = self.data[key]()
if o is None:
raise KeyError(key)
@@ -140,18 +120,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return o
def __delitem__(self, key):
- if self._pending_removals:
- self._commit_removals()
del self.data[key]
def __len__(self):
- if self._pending_removals:
- self._commit_removals()
return len(self.data)
def __contains__(self, key):
- if self._pending_removals:
- self._commit_removals()
try:
o = self.data[key]()
except KeyError:
@@ -162,38 +136,28 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return "<%s at %#x>" % (self.__class__.__name__, id(self))
def __setitem__(self, key, value):
- if self._pending_removals:
- self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key)
def copy(self):
- if self._pending_removals:
- self._commit_removals()
new = WeakValueDictionary()
- with _IterationGuard(self):
- for key, wr in self.data.items():
- o = wr()
- if o is not None:
- new[key] = o
+ for key, wr in self.data.copy().items():
+ o = wr()
+ if o is not None:
+ new[key] = o
return new
__copy__ = copy
def __deepcopy__(self, memo):
from copy import deepcopy
- if self._pending_removals:
- self._commit_removals()
new = self.__class__()
- with _IterationGuard(self):
- for key, wr in self.data.items():
- o = wr()
- if o is not None:
- new[deepcopy(key, memo)] = o
+ for key, wr in self.data.copy().items():
+ o = wr()
+ if o is not None:
+ new[deepcopy(key, memo)] = o
return new
def get(self, key, default=None):
- if self._pending_removals:
- self._commit_removals()
try:
wr = self.data[key]
except KeyError:
@@ -207,21 +171,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return o
def items(self):
- if self._pending_removals:
- self._commit_removals()
- with _IterationGuard(self):
- for k, wr in self.data.items():
- v = wr()
- if v is not None:
- yield k, v
+ for k, wr in self.data.copy().items():
+ v = wr()
+ if v is not None:
+ yield k, v
def keys(self):
- if self._pending_removals:
- self._commit_removals()
- with _IterationGuard(self):
- for k, wr in self.data.items():
- if wr() is not None:
- yield k
+ for k, wr in self.data.copy().items():
+ if wr() is not None:
+ yield k
__iter__ = keys
@@ -235,23 +193,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
keep the values around longer than needed.
"""
- if self._pending_removals:
- self._commit_removals()
- with _IterationGuard(self):
- yield from self.data.values()
+ yield from self.data.copy().values()
def values(self):
- if self._pending_removals:
- self._commit_removals()
- with _IterationGuard(self):
- for wr in self.data.values():
- obj = wr()
- if obj is not None:
- yield obj
+ for wr in self.data.copy().values():
+ obj = wr()
+ if obj is not None:
+ yield obj
def popitem(self):
- if self._pending_removals:
- self._commit_removals()
while True:
key, wr = self.data.popitem()
o = wr()
@@ -259,8 +209,6 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return key, o
def pop(self, key, *args):
- if self._pending_removals:
- self._commit_removals()
try:
o = self.data.pop(key)()
except KeyError:
@@ -279,16 +227,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
except KeyError:
o = None
if o is None:
- if self._pending_removals:
- self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key)
return default
else:
return o
def update(self, other=None, /, **kwargs):
- if self._pending_removals:
- self._commit_removals()
d = self.data
if other is not None:
if not hasattr(other, "items"):
@@ -308,9 +252,7 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
keep the values around longer than needed.
"""
- if self._pending_removals:
- self._commit_removals()
- return list(self.data.values())
+ return list(self.data.copy().values())
def __ior__(self, other):
self.update(other)
@@ -369,57 +311,22 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
def remove(k, selfref=ref(self)):
self = selfref()
if self is not None:
- if self._iterating:
- self._pending_removals.append(k)
- else:
- try:
- del self.data[k]
- except KeyError:
- pass
+ try:
+ del self.data[k]
+ except KeyError:
+ pass
self._remove = remove
- # A list of dead weakrefs (keys to be removed)
- self._pending_removals = []
- self._iterating = set()
- self._dirty_len = False
if dict is not None:
self.update(dict)
- def _commit_removals(self):
- # NOTE: We don't need to call this method before mutating the dict,
- # because a dead weakref never compares equal to a live weakref,
- # even if they happened to refer to equal objects.
- # However, it means keys may already have been removed.
- pop = self._pending_removals.pop
- d = self.data
- while True:
- try:
- key = pop()
- except IndexError:
- return
-
- try:
- del d[key]
- except KeyError:
- pass
-
- def _scrub_removals(self):
- d = self.data
- self._pending_removals = [k for k in self._pending_removals if k in d]
- self._dirty_len = False
-
def __delitem__(self, key):
- self._dirty_len = True
del self.data[ref(key)]
def __getitem__(self, key):
return self.data[ref(key)]
def __len__(self):
- if self._dirty_len and self._pending_removals:
- # self._pending_removals may still contain keys which were
- # explicitly removed, we have to scrub them (see issue #21173).
- self._scrub_removals()
- return len(self.data) - len(self._pending_removals)
+ return len(self.data)
def __repr__(self):
return "<%s at %#x>" % (self.__class__.__name__, id(self))
@@ -429,11 +336,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
def copy(self):
new = WeakKeyDictionary()
- with _IterationGuard(self):
- for key, value in self.data.items():
- o = key()
- if o is not None:
- new[o] = value
+ for key, value in self.data.copy().items():
+ o = key()
+ if o is not None:
+ new[o] = value
return new
__copy__ = copy
@@ -441,11 +347,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
def __deepcopy__(self, memo):
from copy import deepcopy
new = self.__class__()
- with _IterationGuard(self):
- for key, value in self.data.items():
- o = key()
- if o is not None:
- new[o] = deepcopy(value, memo)
+ for key, value in self.data.copy().items():
+ o = key()
+ if o is not None:
+ new[o] = deepcopy(value, memo)
return new
def get(self, key, default=None):
@@ -459,26 +364,23 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
return wr in self.data
def items(self):
- with _IterationGuard(self):
- for wr, value in self.data.items():
- key = wr()
- if key is not None:
- yield key, value
+ for wr, value in self.data.copy().items():
+ key = wr()
+ if key is not None:
+ yield key, value
def keys(self):
- with _IterationGuard(self):
- for wr in self.data:
- obj = wr()
- if obj is not None:
- yield obj
+ for wr in self.data.copy():
+ obj = wr()
+ if obj is not None:
+ yield obj
__iter__ = keys
def values(self):
- with _IterationGuard(self):
- for wr, value in self.data.items():
- if wr() is not None:
- yield value
+ for wr, value in self.data.copy().items():
+ if wr() is not None:
+ yield value
def keyrefs(self):
"""Return a list of weak references to the keys.
@@ -493,7 +395,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
return list(self.data)
def popitem(self):
- self._dirty_len = True
while True:
key, value = self.data.popitem()
o = key()
@@ -501,7 +402,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
return o, value
def pop(self, key, *args):
- self._dirty_len = True
return self.data.pop(ref(key), *args)
def setdefault(self, key, default=None):
diff --git a/Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst b/Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst
new file mode 100644
index 0000000..d086045
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst
@@ -0,0 +1 @@
+Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.