diff options
author | Raymond Hettinger <python@rcn.com> | 2003-08-17 08:34:09 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2003-08-17 08:34:09 (GMT) |
commit | 6a1801271aa4011cf26e7a64b52f6be10997f267 (patch) | |
tree | 3aaa9dc996dcfd676d809bf585cd498420b3bcbf /Lib/sets.py | |
parent | 236ffba40005039cfbc0bd7475345ef7fedf86c5 (diff) | |
download | cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.zip cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.tar.gz cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.tar.bz2 |
Improvements to set.py:
* Relaxed the argument restrictions for non-operator methods. They now
allow any iterable instead of requiring a set. This makes the module
a little easier to use and paves the way for an efficient C
implementation which can take better advantage of iterable arguments
while screening out immutables.
* Deprecated Set.update() because it now duplicates Set.union_update()
* Adapted the tests and docs to include the above changes.
* Added more test coverage including testing identities and checking
to make sure non-restartable generators work as arguments.
Will backport to Py2.3.1 so that the interface remains consistent
across versions. The deprecation of update() will be changed to
a FutureWarning.
Diffstat (limited to 'Lib/sets.py')
-rw-r--r-- | Lib/sets.py | 95 |
1 files changed, 57 insertions, 38 deletions
diff --git a/Lib/sets.py b/Lib/sets.py index 32eb0aa..3fe8c8f 100644 --- a/Lib/sets.py +++ b/Lib/sets.py @@ -196,17 +196,16 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - result = self.__class__() - result._data = self._data.copy() - result._data.update(other._data) - return result + return self.union(other) def union(self, other): """Return the union of two sets as a new set. (I.e. all elements that are in either set.) """ - return self | other + result = self.__class__(self) + result._update(other) + return result def __and__(self, other): """Return the intersection of two sets as a new set. @@ -215,19 +214,21 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - if len(self) <= len(other): - little, big = self, other - else: - little, big = other, self - common = ifilter(big._data.has_key, little) - return self.__class__(common) + return self.intersection(other) def intersection(self, other): """Return the intersection of two sets as a new set. (I.e. all elements that are in both sets.) """ - return self & other + if not isinstance(other, BaseSet): + other = Set(other) + if len(self) <= len(other): + little, big = self, other + else: + little, big = other, self + common = ifilter(big._data.has_key, little) + return self.__class__(common) def __xor__(self, other): """Return the symmetric difference of two sets as a new set. @@ -236,24 +237,27 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference(self, other): + """Return the symmetric difference of two sets as a new set. + + (I.e. all elements that are in exactly one of the sets.) + """ result = self.__class__() data = result._data value = True selfdata = self._data - otherdata = other._data + try: + otherdata = other._data + except AttributeError: + otherdata = Set(other)._data for elt in ifilterfalse(otherdata.has_key, selfdata): data[elt] = value for elt in ifilterfalse(selfdata.has_key, otherdata): data[elt] = value return result - def symmetric_difference(self, other): - """Return the symmetric difference of two sets as a new set. - - (I.e. all elements that are in exactly one of the sets.) - """ - return self ^ other - def __sub__(self, other): """Return the difference of two sets as a new Set. @@ -261,19 +265,23 @@ class BaseSet(object): """ if not isinstance(other, BaseSet): return NotImplemented - result = self.__class__() - data = result._data - value = True - for elt in ifilterfalse(other._data.has_key, self): - data[elt] = value - return result + return self.difference(other) def difference(self, other): """Return the difference of two sets as a new Set. (I.e. all elements that are in this set and not in the other.) """ - return self - other + result = self.__class__() + data = result._data + try: + otherdata = other._data + except AttributeError: + otherdata = Set(other)._data + value = True + for elt in ifilterfalse(otherdata.has_key, self): + data[elt] = value + return result # Membership test @@ -441,7 +449,7 @@ class Set(BaseSet): def union_update(self, other): """Update a set with the union of itself and another.""" - self |= other + self._update(other) def __iand__(self, other): """Update a set with the intersection of itself and another.""" @@ -451,40 +459,51 @@ class Set(BaseSet): def intersection_update(self, other): """Update a set with the intersection of itself and another.""" - self &= other + if isinstance(other, BaseSet): + self &= other + else: + self._data = (self.intersection(other))._data def __ixor__(self, other): """Update a set with the symmetric difference of itself and another.""" self._binary_sanity_check(other) + self.symmetric_difference_update(other) + return self + + def symmetric_difference_update(self, other): + """Update a set with the symmetric difference of itself and another.""" data = self._data value = True + if not isinstance(other, BaseSet): + other = Set(other) for elt in other: if elt in data: del data[elt] else: data[elt] = value - return self - - def symmetric_difference_update(self, other): - """Update a set with the symmetric difference of itself and another.""" - self ^= other def __isub__(self, other): """Remove all elements of another set from this set.""" self._binary_sanity_check(other) - data = self._data - for elt in ifilter(data.has_key, other): - del data[elt] + self.difference_update(other) return self def difference_update(self, other): """Remove all elements of another set from this set.""" - self -= other + data = self._data + if not isinstance(other, BaseSet): + other = Set(other) + for elt in ifilter(data.has_key, other): + del data[elt] # Python dict-like mass mutations: update, clear def update(self, iterable): """Add all values from an iterable (such as a list or file).""" + import warnings + warnings.warn("The update() method deprecated; " + "Use union_update() instead", + DeprecationWarning) self._update(iterable) def clear(self): |