diff options
author | Raymond Hettinger <python@rcn.com> | 2003-08-17 08:34:09 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2003-08-17 08:34:09 (GMT) |
commit | 6a1801271aa4011cf26e7a64b52f6be10997f267 (patch) | |
tree | 3aaa9dc996dcfd676d809bf585cd498420b3bcbf /Lib | |
parent | 236ffba40005039cfbc0bd7475345ef7fedf86c5 (diff) | |
download | cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.zip cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.tar.gz cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.tar.bz2 |
Improvements to set.py:
* Relaxed the argument restrictions for non-operator methods. They now
allow any iterable instead of requiring a set. This makes the module
a little easier to use and paves the way for an efficient C
implementation which can take better advantage of iterable arguments
while screening out immutables.
* Deprecated Set.update() because it now duplicates Set.union_update()
* Adapted the tests and docs to include the above changes.
* Added more test coverage including testing identities and checking
to make sure non-restartable generators work as arguments.
Will backport to Py2.3.1 so that the interface remains consistent
across versions. The deprecation of update() will be changed to
a FutureWarning.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/sets.py | 95 | ||||
-rw-r--r-- | Lib/test/test_sets.py | 144 |
2 files changed, 190 insertions, 49 deletions
diff --git a/Lib/sets.py b/Lib/sets.py index 32eb0aa..3fe8c8f 100644 --- a/Lib/sets.py +++ b/Lib/sets.py @@ -196,17 +196,16 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - result = self.__class__() - result._data = self._data.copy() - result._data.update(other._data) - return result + return self.union(other) def union(self, other): """Return the union of two sets as a new set. (I.e. all elements that are in either set.) """ - return self | other + result = self.__class__(self) + result._update(other) + return result def __and__(self, other): """Return the intersection of two sets as a new set. @@ -215,19 +214,21 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - if len(self) <= len(other): - little, big = self, other - else: - little, big = other, self - common = ifilter(big._data.has_key, little) - return self.__class__(common) + return self.intersection(other) def intersection(self, other): """Return the intersection of two sets as a new set. (I.e. all elements that are in both sets.) """ - return self & other + if not isinstance(other, BaseSet): + other = Set(other) + if len(self) <= len(other): + little, big = self, other + else: + little, big = other, self + common = ifilter(big._data.has_key, little) + return self.__class__(common) def __xor__(self, other): """Return the symmetric difference of two sets as a new set. @@ -236,24 +237,27 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference(self, other): + """Return the symmetric difference of two sets as a new set. + + (I.e. all elements that are in exactly one of the sets.) + """ result = self.__class__() data = result._data value = True selfdata = self._data - otherdata = other._data + try: + otherdata = other._data + except AttributeError: + otherdata = Set(other)._data for elt in ifilterfalse(otherdata.has_key, selfdata): data[elt] = value for elt in ifilterfalse(selfdata.has_key, otherdata): data[elt] = value return result - def symmetric_difference(self, other): - """Return the symmetric difference of two sets as a new set. - - (I.e. all elements that are in exactly one of the sets.) - """ - return self ^ other - def __sub__(self, other): """Return the difference of two sets as a new Set. @@ -261,19 +265,23 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - result = self.__class__() - data = result._data - value = True - for elt in ifilterfalse(other._data.has_key, self): - data[elt] = value - return result + return self.difference(other) def difference(self, other): """Return the difference of two sets as a new Set. (I.e. all elements that are in this set and not in the other.) """ - return self - other + result = self.__class__() + data = result._data + try: + otherdata = other._data + except AttributeError: + otherdata = Set(other)._data + value = True + for elt in ifilterfalse(otherdata.has_key, self): + data[elt] = value + return result # Membership test @@ -441,7 +449,7 @@ class Set(BaseSet): def union_update(self, other): """Update a set with the union of itself and another.""" - self |= other + self._update(other) def __iand__(self, other): """Update a set with the intersection of itself and another.""" @@ -451,40 +459,51 @@ class Set(BaseSet): def intersection_update(self, other): """Update a set with the intersection of itself and another.""" - self &= other + if isinstance(other, BaseSet): + self &= other + else: + self._data = (self.intersection(other))._data def __ixor__(self, other): """Update a set with the symmetric difference of itself and another.""" self._binary_sanity_check(other) + self.symmetric_difference_update(other) + return self + + def symmetric_difference_update(self, other): + """Update a set with the symmetric difference of itself and another.""" data = self._data value = True + if not isinstance(other, BaseSet): + other = Set(other) for elt in other: if elt in data: del data[elt] else: data[elt] = value - return self - - def symmetric_difference_update(self, other): - """Update a set with the symmetric difference of itself and another.""" - self ^= other def __isub__(self, other): """Remove all elements of another set from this set.""" self._binary_sanity_check(other) - data = self._data - for elt in ifilter(data.has_key, other): - del data[elt] + self.difference_update(other) return self def difference_update(self, other): """Remove all elements of another set from this set.""" - self -= other + data = self._data + if not isinstance(other, BaseSet): + other = Set(other) + for elt in ifilter(data.has_key, other): + del data[elt] # Python dict-like mass mutations: update, clear def update(self, iterable): """Add all values from an iterable (such as a list or file).""" + import warnings + warnings.warn("The update() method deprecated; " + "Use union_update() instead", + DeprecationWarning) self._update(iterable) def clear(self): diff --git a/Lib/test/test_sets.py b/Lib/test/test_sets.py index 9051d49..f0a1925 100644 --- a/Lib/test/test_sets.py +++ b/Lib/test/test_sets.py @@ -152,7 +152,7 @@ class TestExceptionPropagation(unittest.TestCase): self.assertRaises(TypeError, Set, baditer()) def test_instancesWithoutException(self): - """All of these iterables should load without exception.""" + # All of these iterables should load without exception. Set([1,2,3]) Set((1,2,3)) Set({'one':1, 'two':2, 'three':3}) @@ -392,15 +392,15 @@ class TestMutate(unittest.TestCase): self.failUnless(v in popped) def test_update_empty_tuple(self): - self.set.update(()) + self.set.union_update(()) self.assertEqual(self.set, Set(self.values)) def test_update_unit_tuple_overlap(self): - self.set.update(("a",)) + self.set.union_update(("a",)) self.assertEqual(self.set, Set(self.values)) def test_update_unit_tuple_non_overlap(self): - self.set.update(("a", "z")) + self.set.union_update(("a", "z")) self.assertEqual(self.set, Set(self.values + ["z"])) #============================================================================== @@ -503,7 +503,7 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): self.assertRaises(TypeError, lambda: self.other > self.set) self.assertRaises(TypeError, lambda: self.other >= self.set) - def test_union_update(self): + def test_union_update_operator(self): try: self.set |= self.other except TypeError: @@ -511,11 +511,21 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_union_update(self): + if self.otherIsIterable: + self.set.union_update(self.other) + else: + self.assertRaises(TypeError, self.set.union_update, self.other) + def test_union(self): self.assertRaises(TypeError, lambda: self.set | self.other) self.assertRaises(TypeError, lambda: self.other | self.set) + if self.otherIsIterable: + self.set.union(self.other) + else: + self.assertRaises(TypeError, self.set.union, self.other) - def test_intersection_update(self): + def test_intersection_update_operator(self): try: self.set &= self.other except TypeError: @@ -523,11 +533,23 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_intersection_update(self): + if self.otherIsIterable: + self.set.intersection_update(self.other) + else: + self.assertRaises(TypeError, + self.set.intersection_update, + self.other) + def test_intersection(self): self.assertRaises(TypeError, lambda: self.set & self.other) self.assertRaises(TypeError, lambda: self.other & self.set) + if self.otherIsIterable: + self.set.intersection(self.other) + else: + self.assertRaises(TypeError, self.set.intersection, self.other) - def test_sym_difference_update(self): + def test_sym_difference_update_operator(self): try: self.set ^= self.other except TypeError: @@ -535,11 +557,23 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_sym_difference_update(self): + if self.otherIsIterable: + self.set.symmetric_difference_update(self.other) + else: + self.assertRaises(TypeError, + self.set.symmetric_difference_update, + self.other) + def test_sym_difference(self): self.assertRaises(TypeError, lambda: self.set ^ self.other) self.assertRaises(TypeError, lambda: self.other ^ self.set) + if self.otherIsIterable: + self.set.symmetric_difference(self.other) + else: + self.assertRaises(TypeError, self.set.symmetric_difference, self.other) - def test_difference_update(self): + def test_difference_update_operator(self): try: self.set -= self.other except TypeError: @@ -547,16 +581,28 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_difference_update(self): + if self.otherIsIterable: + self.set.difference_update(self.other) + else: + self.assertRaises(TypeError, + self.set.difference_update, + self.other) + def test_difference(self): self.assertRaises(TypeError, lambda: self.set - self.other) self.assertRaises(TypeError, lambda: self.other - self.set) - + if self.otherIsIterable: + self.set.difference(self.other) + else: + self.assertRaises(TypeError, self.set.difference, self.other) #------------------------------------------------------------------------------ class TestOnlySetsNumeric(TestOnlySetsInBinaryOps): def setUp(self): self.set = Set((1, 2, 3)) self.other = 19 + self.otherIsIterable = False #------------------------------------------------------------------------------ @@ -564,6 +610,7 @@ class TestOnlySetsDict(TestOnlySetsInBinaryOps): def setUp(self): self.set = Set((1, 2, 3)) self.other = {1:2, 3:4} + self.otherIsIterable = True #------------------------------------------------------------------------------ @@ -571,6 +618,34 @@ class TestOnlySetsOperator(TestOnlySetsInBinaryOps): def setUp(self): self.set = Set((1, 2, 3)) self.other = operator.add + self.otherIsIterable = False + +#------------------------------------------------------------------------------ + +class TestOnlySetsTuple(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = (2, 4, 6) + self.otherIsIterable = True + +#------------------------------------------------------------------------------ + +class TestOnlySetsString(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = 'abc' + self.otherIsIterable = True + +#------------------------------------------------------------------------------ + +class TestOnlySetsGenerator(TestOnlySetsInBinaryOps): + def setUp(self): + def gen(): + for i in xrange(0, 10, 2): + yield i + self.set = Set((1, 2, 3)) + self.other = gen() + self.otherIsIterable = True #============================================================================== @@ -625,6 +700,49 @@ class TestCopyingNested(TestCopying): #============================================================================== +class TestIdentities(unittest.TestCase): + def setUp(self): + self.a = Set('abracadabra') + self.b = Set('alacazam') + + def test_binopsVsSubsets(self): + a, b = self.a, self.b + self.assert_(a - b < a) + self.assert_(b - a < b) + self.assert_(a & b < a) + self.assert_(a & b < b) + self.assert_(a | b > a) + self.assert_(a | b > b) + self.assert_(a ^ b < a | b) + + def test_commutativity(self): + a, b = self.a, self.b + self.assertEqual(a&b, b&a) + self.assertEqual(a|b, b|a) + self.assertEqual(a^b, b^a) + if a != b: + self.assertNotEqual(a-b, b-a) + + def test_summations(self): + # check that sums of parts equal the whole + a, b = self.a, self.b + self.assertEqual((a-b)|(a&b)|(b-a), a|b) + self.assertEqual((a&b)|(a^b), a|b) + self.assertEqual(a|(b-a), a|b) + self.assertEqual((a-b)|b, a|b) + self.assertEqual((a-b)|(a&b), a) + self.assertEqual((b-a)|(a&b), b) + self.assertEqual((a-b)|(b-a), a^b) + + def test_exclusion(self): + # check that inverse operations show non-overlap + a, b, zero = self.a, self.b, Set() + self.assertEqual((a-b)&b, zero) + self.assertEqual((b-a)&a, zero) + self.assertEqual((a&b)&(a^b), zero) + +#============================================================================== + libreftest = """ Example from the Library Reference: Doc/lib/libsets.tex @@ -643,7 +761,7 @@ Example from the Library Reference: Doc/lib/libsets.tex Set(['Jack', 'Jane', 'Janice', 'John', 'Marvin']) >>> employees.issuperset(engineers) # superset test False ->>> employees.update(engineers) # update from another set +>>> employees.union_update(engineers) # update from another set >>> employees.issuperset(engineers) True >>> for group in [engineers, programmers, managers, employees]: @@ -680,11 +798,15 @@ def test_main(verbose=None): TestOnlySetsNumeric, TestOnlySetsDict, TestOnlySetsOperator, + TestOnlySetsTuple, + TestOnlySetsString, + TestOnlySetsGenerator, TestCopyingEmpty, TestCopyingSingleton, TestCopyingTriple, TestCopyingTuple, - TestCopyingNested + TestCopyingNested, + TestIdentities, ) test_support.run_doctest(test_sets, verbose) |