diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2012-03-04 19:47:05 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2012-03-04 19:47:05 (GMT) |
commit | 94c2d6df544abd9eb0601cb5387774599cf0cdf1 (patch) | |
tree | 5b9d3679f52705913e5919ef2309ba66cd313c02 /Lib | |
parent | 859416e980dd6b4b96ad066e723709928c5afa81 (diff) | |
download | cpython-94c2d6df544abd9eb0601cb5387774599cf0cdf1.zip cpython-94c2d6df544abd9eb0601cb5387774599cf0cdf1.tar.gz cpython-94c2d6df544abd9eb0601cb5387774599cf0cdf1.tar.bz2 |
Fix some set algebra methods of WeakSet objects.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/_weakrefset.py | 41 | ||||
-rw-r--r-- | Lib/test/test_weakset.py | 22 |
2 files changed, 30 insertions, 33 deletions
diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py index b8d8043..ff613e6 100644 --- a/Lib/_weakrefset.py +++ b/Lib/_weakrefset.py @@ -123,26 +123,14 @@ class WeakSet(object): self.update(other) return self - # Helper functions for simple delegating methods. - def _apply(self, other, method): - if not isinstance(other, self.__class__): - other = self.__class__(other) - newdata = method(other.data) - newset = self.__class__() - newset.data = newdata - return newset - def difference(self, other): - return self._apply(other, self.data.difference) + newset = self.copy() + newset.difference_update(other) + return newset __sub__ = difference def difference_update(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.difference_update(ref(item) for item in other) + self.__isub__(other) def __isub__(self, other): if self._pending_removals: self._commit_removals() @@ -153,13 +141,11 @@ class WeakSet(object): return self def intersection(self, other): - return self._apply(other, self.data.intersection) + return self.__class__(item for item in other if item in self) __and__ = intersection def intersection_update(self, other): - if self._pending_removals: - self._commit_removals() - self.data.intersection_update(ref(item) for item in other) + self.__iand__(other) def __iand__(self, other): if self._pending_removals: self._commit_removals() @@ -186,27 +172,24 @@ class WeakSet(object): return self.data == set(ref(item) for item in other) def symmetric_difference(self, other): - return self._apply(other, self.data.symmetric_difference) + newset = self.copy() + newset.symmetric_difference_update(other) + return newset __xor__ = symmetric_difference def symmetric_difference_update(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.symmetric_difference_update(ref(item) for item in other) + self.__ixor__(other) def __ixor__(self, other): if self._pending_removals: self._commit_removals() if self is other: self.data.clear() else: - self.data.symmetric_difference_update(ref(item) for item in other) + self.data.symmetric_difference_update(ref(item, self._remove) for item in other) return self def union(self, other): - return self._apply(other, self.data.union) + return self.__class__(e for s in (self, other) for e in s) __or__ = union def isdisjoint(self, other): diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py index f981bdd..1f82a7d 100644 --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -83,6 +83,11 @@ class TestWeakSet(unittest.TestCase): x = WeakSet(self.items + self.items2) c = C(self.items2) self.assertEqual(self.s.union(c), x) + del c + self.assertEqual(len(u), len(self.items) + len(self.items2)) + self.items2.pop() + gc.collect() + self.assertEqual(len(u), len(self.items) + len(self.items2)) def test_or(self): i = self.s.union(self.items2) @@ -90,14 +95,19 @@ class TestWeakSet(unittest.TestCase): self.assertEqual(self.s | frozenset(self.items2), i) def test_intersection(self): - i = self.s.intersection(self.items2) + s = WeakSet(self.letters) + i = s.intersection(self.items2) for c in self.letters: - self.assertEqual(c in i, c in self.d and c in self.items2) - self.assertEqual(self.s, WeakSet(self.items)) + self.assertEqual(c in i, c in self.items2 and c in self.letters) + self.assertEqual(s, WeakSet(self.letters)) self.assertEqual(type(i), WeakSet) for C in set, frozenset, dict.fromkeys, list, tuple: x = WeakSet([]) - self.assertEqual(self.s.intersection(C(self.items2)), x) + self.assertEqual(i.intersection(C(self.items)), x) + self.assertEqual(len(i), len(self.items2)) + self.items2.pop() + gc.collect() + self.assertEqual(len(i), len(self.items2)) def test_isdisjoint(self): self.assertTrue(self.s.isdisjoint(WeakSet(self.items2))) @@ -128,6 +138,10 @@ class TestWeakSet(unittest.TestCase): self.assertEqual(self.s, WeakSet(self.items)) self.assertEqual(type(i), WeakSet) self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) + self.assertEqual(len(i), len(self.items) + len(self.items2)) + self.items2.pop() + gc.collect() + self.assertEqual(len(i), len(self.items) + len(self.items2)) def test_xor(self): i = self.s.symmetric_difference(self.items2) |