From 6a1801271aa4011cf26e7a64b52f6be10997f267 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Sun, 17 Aug 2003 08:34:09 +0000 Subject: Improvements to sets.py: * Relaxed the argument restrictions for non-operator methods. They now allow any iterable instead of requiring a set. This makes the module a little easier to use and paves the way for an efficient C implementation which can take better advantage of iterable arguments while screening out immutables. * Deprecated Set.update() because it now duplicates Set.union_update() * Adapted the tests and docs to include the above changes. * Added more test coverage including testing identities and checking to make sure non-restartable generators work as arguments. Will backport to Py2.3.1 so that the interface remains consistent across versions. The deprecation of update() will be changed to a FutureWarning. --- Doc/lib/libsets.tex | 22 ++++++-- Lib/sets.py | 95 ++++++++++++++++++++------------- Lib/test/test_sets.py | 144 ++++++++++++++++++++++++++++++++++++++++++++++---- Misc/NEWS | 5 +- 4 files changed, 213 insertions(+), 53 deletions(-) diff --git a/Doc/lib/libsets.tex b/Doc/lib/libsets.tex index 71b6d3d..8551ab6 100644 --- a/Doc/lib/libsets.tex +++ b/Doc/lib/libsets.tex @@ -91,6 +91,15 @@ the following operations: {new set with a shallow copy of \var{s}} \end{tableiii} +Note, the non-operator versions of \method{union()}, +\method{intersection()}, \method{difference()}, and +\method{symmetric_difference()} will accept any iterable as an argument. +In contrast, their operator based counterparts require their arguments to +be sets. This precludes error-prone constructions like +\code{Set('abc') \&\ 'cbs'} in favor of the more readable +\code{Set('abc').intersection('cbs')}. +\versionchanged[Formerly all arguments were required to be sets]{2.3.1} + In addition, both \class{Set} and \class{ImmutableSet} support set to set comparisons. 
Two sets are equal if and only if every element of each set is contained in the other (each is a subset @@ -145,12 +154,19 @@ but not found in \class{ImmutableSet}: \lineiii{\var{s}.pop()}{} {remove and return an arbitrary element from \var{s}; raises KeyError if empty} - \lineiii{\var{s}.update(\var{t})}{} - {add elements from \var{t} to set \var{s}} \lineiii{\var{s}.clear()}{} {remove all elements from set \var{s}} \end{tableiii} +\versionchanged[Earlier versions had an \method{update()} method; use + \method{union_update()} instead]{2.3.1} + +Note, this non-operator versions of \method{union_update()}, +\method{intersection_update()}, \method{difference_update()}, and +\method{symmetric_difference_update()} will accept any iterable as +an argument. +\versionchanged[Formerly all arguments were required to be sets]{2.3.1} + \subsection{Example \label{set-example}} @@ -167,7 +183,7 @@ but not found in \class{ImmutableSet}: Set(['Jane', 'Marvin', 'Janice', 'John', 'Jack']) >>> employees.issuperset(engineers) # superset test False ->>> employees.update(engineers) # update from another set +>>> employees.union_update(engineers) # update from another set >>> employees.issuperset(engineers) True >>> for group in [engineers, programmers, management, employees]: diff --git a/Lib/sets.py b/Lib/sets.py index 32eb0aa..3fe8c8f 100644 --- a/Lib/sets.py +++ b/Lib/sets.py @@ -196,17 +196,16 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - result = self.__class__() - result._data = self._data.copy() - result._data.update(other._data) - return result + return self.union(other) def union(self, other): """Return the union of two sets as a new set. (I.e. all elements that are in either set.) """ - return self | other + result = self.__class__(self) + result._update(other) + return result def __and__(self, other): """Return the intersection of two sets as a new set. 
@@ -215,19 +214,21 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - if len(self) <= len(other): - little, big = self, other - else: - little, big = other, self - common = ifilter(big._data.has_key, little) - return self.__class__(common) + return self.intersection(other) def intersection(self, other): """Return the intersection of two sets as a new set. (I.e. all elements that are in both sets.) """ - return self & other + if not isinstance(other, BaseSet): + other = Set(other) + if len(self) <= len(other): + little, big = self, other + else: + little, big = other, self + common = ifilter(big._data.has_key, little) + return self.__class__(common) def __xor__(self, other): """Return the symmetric difference of two sets as a new set. @@ -236,24 +237,27 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference(self, other): + """Return the symmetric difference of two sets as a new set. + + (I.e. all elements that are in exactly one of the sets.) + """ result = self.__class__() data = result._data value = True selfdata = self._data - otherdata = other._data + try: + otherdata = other._data + except AttributeError: + otherdata = Set(other)._data for elt in ifilterfalse(otherdata.has_key, selfdata): data[elt] = value for elt in ifilterfalse(selfdata.has_key, otherdata): data[elt] = value return result - def symmetric_difference(self, other): - """Return the symmetric difference of two sets as a new set. - - (I.e. all elements that are in exactly one of the sets.) - """ - return self ^ other - def __sub__(self, other): """Return the difference of two sets as a new Set. 
@@ -261,19 +265,23 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - result = self.__class__() - data = result._data - value = True - for elt in ifilterfalse(other._data.has_key, self): - data[elt] = value - return result + return self.difference(other) def difference(self, other): """Return the difference of two sets as a new Set. (I.e. all elements that are in this set and not in the other.) """ - return self - other + result = self.__class__() + data = result._data + try: + otherdata = other._data + except AttributeError: + otherdata = Set(other)._data + value = True + for elt in ifilterfalse(otherdata.has_key, self): + data[elt] = value + return result # Membership test @@ -441,7 +449,7 @@ class Set(BaseSet): def union_update(self, other): """Update a set with the union of itself and another.""" - self |= other + self._update(other) def __iand__(self, other): """Update a set with the intersection of itself and another.""" @@ -451,40 +459,51 @@ class Set(BaseSet): def intersection_update(self, other): """Update a set with the intersection of itself and another.""" - self &= other + if isinstance(other, BaseSet): + self &= other + else: + self._data = (self.intersection(other))._data def __ixor__(self, other): """Update a set with the symmetric difference of itself and another.""" self._binary_sanity_check(other) + self.symmetric_difference_update(other) + return self + + def symmetric_difference_update(self, other): + """Update a set with the symmetric difference of itself and another.""" data = self._data value = True + if not isinstance(other, BaseSet): + other = Set(other) for elt in other: if elt in data: del data[elt] else: data[elt] = value - return self - - def symmetric_difference_update(self, other): - """Update a set with the symmetric difference of itself and another.""" - self ^= other def __isub__(self, other): """Remove all elements of another set from this set.""" self._binary_sanity_check(other) - data = 
self._data - for elt in ifilter(data.has_key, other): - del data[elt] + self.difference_update(other) return self def difference_update(self, other): """Remove all elements of another set from this set.""" - self -= other + data = self._data + if not isinstance(other, BaseSet): + other = Set(other) + for elt in ifilter(data.has_key, other): + del data[elt] # Python dict-like mass mutations: update, clear def update(self, iterable): """Add all values from an iterable (such as a list or file).""" + import warnings + warnings.warn("The update() method deprecated; " + "Use union_update() instead", + DeprecationWarning) self._update(iterable) def clear(self): diff --git a/Lib/test/test_sets.py b/Lib/test/test_sets.py index 9051d49..f0a1925 100644 --- a/Lib/test/test_sets.py +++ b/Lib/test/test_sets.py @@ -152,7 +152,7 @@ class TestExceptionPropagation(unittest.TestCase): self.assertRaises(TypeError, Set, baditer()) def test_instancesWithoutException(self): - """All of these iterables should load without exception.""" + # All of these iterables should load without exception. 
Set([1,2,3]) Set((1,2,3)) Set({'one':1, 'two':2, 'three':3}) @@ -392,15 +392,15 @@ class TestMutate(unittest.TestCase): self.failUnless(v in popped) def test_update_empty_tuple(self): - self.set.update(()) + self.set.union_update(()) self.assertEqual(self.set, Set(self.values)) def test_update_unit_tuple_overlap(self): - self.set.update(("a",)) + self.set.union_update(("a",)) self.assertEqual(self.set, Set(self.values)) def test_update_unit_tuple_non_overlap(self): - self.set.update(("a", "z")) + self.set.union_update(("a", "z")) self.assertEqual(self.set, Set(self.values + ["z"])) #============================================================================== @@ -503,7 +503,7 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): self.assertRaises(TypeError, lambda: self.other > self.set) self.assertRaises(TypeError, lambda: self.other >= self.set) - def test_union_update(self): + def test_union_update_operator(self): try: self.set |= self.other except TypeError: @@ -511,11 +511,21 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_union_update(self): + if self.otherIsIterable: + self.set.union_update(self.other) + else: + self.assertRaises(TypeError, self.set.union_update, self.other) + def test_union(self): self.assertRaises(TypeError, lambda: self.set | self.other) self.assertRaises(TypeError, lambda: self.other | self.set) + if self.otherIsIterable: + self.set.union(self.other) + else: + self.assertRaises(TypeError, self.set.union, self.other) - def test_intersection_update(self): + def test_intersection_update_operator(self): try: self.set &= self.other except TypeError: @@ -523,11 +533,23 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_intersection_update(self): + if self.otherIsIterable: + self.set.intersection_update(self.other) + else: + self.assertRaises(TypeError, + self.set.intersection_update, + self.other) + def test_intersection(self): 
self.assertRaises(TypeError, lambda: self.set & self.other) self.assertRaises(TypeError, lambda: self.other & self.set) + if self.otherIsIterable: + self.set.intersection(self.other) + else: + self.assertRaises(TypeError, self.set.intersection, self.other) - def test_sym_difference_update(self): + def test_sym_difference_update_operator(self): try: self.set ^= self.other except TypeError: @@ -535,11 +557,23 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_sym_difference_update(self): + if self.otherIsIterable: + self.set.symmetric_difference_update(self.other) + else: + self.assertRaises(TypeError, + self.set.symmetric_difference_update, + self.other) + def test_sym_difference(self): self.assertRaises(TypeError, lambda: self.set ^ self.other) self.assertRaises(TypeError, lambda: self.other ^ self.set) + if self.otherIsIterable: + self.set.symmetric_difference(self.other) + else: + self.assertRaises(TypeError, self.set.symmetric_difference, self.other) - def test_difference_update(self): + def test_difference_update_operator(self): try: self.set -= self.other except TypeError: @@ -547,16 +581,28 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): else: self.fail("expected TypeError") + def test_difference_update(self): + if self.otherIsIterable: + self.set.difference_update(self.other) + else: + self.assertRaises(TypeError, + self.set.difference_update, + self.other) + def test_difference(self): self.assertRaises(TypeError, lambda: self.set - self.other) self.assertRaises(TypeError, lambda: self.other - self.set) - + if self.otherIsIterable: + self.set.difference(self.other) + else: + self.assertRaises(TypeError, self.set.difference, self.other) #------------------------------------------------------------------------------ class TestOnlySetsNumeric(TestOnlySetsInBinaryOps): def setUp(self): self.set = Set((1, 2, 3)) self.other = 19 + self.otherIsIterable = False 
#------------------------------------------------------------------------------ @@ -564,6 +610,7 @@ class TestOnlySetsDict(TestOnlySetsInBinaryOps): def setUp(self): self.set = Set((1, 2, 3)) self.other = {1:2, 3:4} + self.otherIsIterable = True #------------------------------------------------------------------------------ @@ -571,6 +618,34 @@ class TestOnlySetsOperator(TestOnlySetsInBinaryOps): def setUp(self): self.set = Set((1, 2, 3)) self.other = operator.add + self.otherIsIterable = False + +#------------------------------------------------------------------------------ + +class TestOnlySetsTuple(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = (2, 4, 6) + self.otherIsIterable = True + +#------------------------------------------------------------------------------ + +class TestOnlySetsString(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = 'abc' + self.otherIsIterable = True + +#------------------------------------------------------------------------------ + +class TestOnlySetsGenerator(TestOnlySetsInBinaryOps): + def setUp(self): + def gen(): + for i in xrange(0, 10, 2): + yield i + self.set = Set((1, 2, 3)) + self.other = gen() + self.otherIsIterable = True #============================================================================== @@ -625,6 +700,49 @@ class TestCopyingNested(TestCopying): #============================================================================== +class TestIdentities(unittest.TestCase): + def setUp(self): + self.a = Set('abracadabra') + self.b = Set('alacazam') + + def test_binopsVsSubsets(self): + a, b = self.a, self.b + self.assert_(a - b < a) + self.assert_(b - a < b) + self.assert_(a & b < a) + self.assert_(a & b < b) + self.assert_(a | b > a) + self.assert_(a | b > b) + self.assert_(a ^ b < a | b) + + def test_commutativity(self): + a, b = self.a, self.b + self.assertEqual(a&b, b&a) + self.assertEqual(a|b, b|a) + self.assertEqual(a^b, b^a) + if a 
!= b: + self.assertNotEqual(a-b, b-a) + + def test_summations(self): + # check that sums of parts equal the whole + a, b = self.a, self.b + self.assertEqual((a-b)|(a&b)|(b-a), a|b) + self.assertEqual((a&b)|(a^b), a|b) + self.assertEqual(a|(b-a), a|b) + self.assertEqual((a-b)|b, a|b) + self.assertEqual((a-b)|(a&b), a) + self.assertEqual((b-a)|(a&b), b) + self.assertEqual((a-b)|(b-a), a^b) + + def test_exclusion(self): + # check that inverse operations show non-overlap + a, b, zero = self.a, self.b, Set() + self.assertEqual((a-b)&b, zero) + self.assertEqual((b-a)&a, zero) + self.assertEqual((a&b)&(a^b), zero) + +#============================================================================== + libreftest = """ Example from the Library Reference: Doc/lib/libsets.tex @@ -643,7 +761,7 @@ Example from the Library Reference: Doc/lib/libsets.tex Set(['Jack', 'Jane', 'Janice', 'John', 'Marvin']) >>> employees.issuperset(engineers) # superset test False ->>> employees.update(engineers) # update from another set +>>> employees.union_update(engineers) # update from another set >>> employees.issuperset(engineers) True >>> for group in [engineers, programmers, managers, employees]: @@ -680,11 +798,15 @@ def test_main(verbose=None): TestOnlySetsNumeric, TestOnlySetsDict, TestOnlySetsOperator, + TestOnlySetsTuple, + TestOnlySetsString, + TestOnlySetsGenerator, TestCopyingEmpty, TestCopyingSingleton, TestCopyingTriple, TestCopyingTuple, - TestCopyingNested + TestCopyingNested, + TestIdentities, ) test_support.run_doctest(test_sets, verbose) diff --git a/Misc/NEWS b/Misc/NEWS index a3a62be..8f24d85 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -30,7 +30,10 @@ Extension modules Library ------- -- sets.py now runs under Py2.2 +- sets.py now runs under Py2.2. In addition, the argument restrictions + for most set methods (but not the operators) have been relaxed to + allow any iterable. Also the Set.update() has been deprecated because + it duplicates Set.union_update(). 
- random.seed() with no arguments or None uses time.time() as a default seed. Modified to match Py2.2 behavior and use fractional seconds so -- cgit v0.12