diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2012-03-04 19:47:05 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2012-03-04 19:47:05 (GMT) |
commit | 9c47ac05d13971d1e6ee2d74afad8d1b57e5b2ac (patch) | |
tree | a023e3782182c4b7578876c7a95a3d144e04cf4d | |
parent | de89d4b09758a1c94dd97be554c967d52759228a (diff) | |
download | cpython-9c47ac05d13971d1e6ee2d74afad8d1b57e5b2ac.zip cpython-9c47ac05d13971d1e6ee2d74afad8d1b57e5b2ac.tar.gz cpython-9c47ac05d13971d1e6ee2d74afad8d1b57e5b2ac.tar.bz2 |
Fix some set algebra methods of WeakSet objects.
-rw-r--r-- | Lib/_weakrefset.py | 41 | ||||
-rw-r--r-- | Lib/test/test_weakset.py | 22 |
2 files changed, 30 insertions, 33 deletions
diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py index c2717e7..1f41841 100644 --- a/Lib/_weakrefset.py +++ b/Lib/_weakrefset.py @@ -121,26 +121,14 @@ class WeakSet: self.update(other) return self - # Helper functions for simple delegating methods. - def _apply(self, other, method): - if not isinstance(other, self.__class__): - other = self.__class__(other) - newdata = method(other.data) - newset = self.__class__() - newset.data = newdata - return newset - def difference(self, other): - return self._apply(other, self.data.difference) + newset = self.copy() + newset.difference_update(other) + return newset __sub__ = difference def difference_update(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.difference_update(ref(item) for item in other) + self.__isub__(other) def __isub__(self, other): if self._pending_removals: self._commit_removals() @@ -151,13 +139,11 @@ class WeakSet: return self def intersection(self, other): - return self._apply(other, self.data.intersection) + return self.__class__(item for item in other if item in self) __and__ = intersection def intersection_update(self, other): - if self._pending_removals: - self._commit_removals() - self.data.intersection_update(ref(item) for item in other) + self.__iand__(other) def __iand__(self, other): if self._pending_removals: self._commit_removals() @@ -184,27 +170,24 @@ class WeakSet: return self.data == set(ref(item) for item in other) def symmetric_difference(self, other): - return self._apply(other, self.data.symmetric_difference) + newset = self.copy() + newset.symmetric_difference_update(other) + return newset __xor__ = symmetric_difference def symmetric_difference_update(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.symmetric_difference_update(ref(item) for item in other) + self.__ixor__(other) def __ixor__(self, other): if self._pending_removals: self._commit_removals() if self is other: self.data.clear() else: - self.data.symmetric_difference_update(ref(item) for item in other) + self.data.symmetric_difference_update(ref(item, self._remove) for item in other) return self def union(self, other): - return self._apply(other, self.data.union) + return self.__class__(e for s in (self, other) for e in s) __or__ = union def isdisjoint(self, other): diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py index 35db7a6..3c71f62 100644 --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -71,6 +71,11 @@ class TestWeakSet(unittest.TestCase): x = WeakSet(self.items + self.items2) c = C(self.items2) self.assertEqual(self.s.union(c), x) + del c + self.assertEqual(len(u), len(self.items) + len(self.items2)) + self.items2.pop() + gc.collect() + self.assertEqual(len(u), len(self.items) + len(self.items2)) def test_or(self): i = self.s.union(self.items2) @@ -78,14 +83,19 @@ class TestWeakSet(unittest.TestCase): self.assertEqual(self.s | frozenset(self.items2), i) def test_intersection(self): - i = self.s.intersection(self.items2) + s = WeakSet(self.letters) + i = s.intersection(self.items2) for c in self.letters: - self.assertEqual(c in i, c in self.d and c in self.items2) - self.assertEqual(self.s, WeakSet(self.items)) + self.assertEqual(c in i, c in self.items2 and c in self.letters) + self.assertEqual(s, WeakSet(self.letters)) self.assertEqual(type(i), WeakSet) for C in set, frozenset, dict.fromkeys, list, tuple: x = WeakSet([]) - self.assertEqual(self.s.intersection(C(self.items2)), x) + self.assertEqual(i.intersection(C(self.items)), x) + self.assertEqual(len(i), len(self.items2)) + self.items2.pop() + gc.collect() + self.assertEqual(len(i), len(self.items2)) def test_isdisjoint(self): self.assertTrue(self.s.isdisjoint(WeakSet(self.items2))) @@ -116,6 +126,10 @@ class TestWeakSet(unittest.TestCase): self.assertEqual(self.s, WeakSet(self.items)) self.assertEqual(type(i), WeakSet) self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) + self.assertEqual(len(i), len(self.items) + len(self.items2)) + self.items2.pop() + gc.collect() + self.assertEqual(len(i), len(self.items) + len(self.items2)) def test_xor(self): i = self.s.symmetric_difference(self.items2) |