summaryrefslogtreecommitdiffstats
path: root/Lib/sets.py
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2003-08-17 08:34:09 (GMT)
committerRaymond Hettinger <python@rcn.com>2003-08-17 08:34:09 (GMT)
commit6a1801271aa4011cf26e7a64b52f6be10997f267 (patch)
tree3aaa9dc996dcfd676d809bf585cd498420b3bcbf /Lib/sets.py
parent236ffba40005039cfbc0bd7475345ef7fedf86c5 (diff)
downloadcpython-6a1801271aa4011cf26e7a64b52f6be10997f267.zip
cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.tar.gz
cpython-6a1801271aa4011cf26e7a64b52f6be10997f267.tar.bz2
Improvements to set.py:
* Relaxed the argument restrictions for non-operator methods. They now allow any iterable instead of requiring a set. This makes the module a little easier to use and paves the way for an efficient C implementation which can take better advantage of iterable arguments while screening out immutables. * Deprecated Set.update() because it now duplicates Set.union_update() * Adapted the tests and docs to include the above changes. * Added more test coverage including testing identities and checking to make sure non-restartable generators work as arguments. Will backport to Py2.3.1 so that the interface remains consistent across versions. The deprecation of update() will be changed to a FutureWarning.
Diffstat (limited to 'Lib/sets.py')
-rw-r--r--Lib/sets.py95
1 files changed, 57 insertions, 38 deletions
diff --git a/Lib/sets.py b/Lib/sets.py
index 32eb0aa..3fe8c8f 100644
--- a/Lib/sets.py
+++ b/Lib/sets.py
@@ -196,17 +196,16 @@ class BaseSet(object):
"""
if not isinstance(other, BaseSet):
return NotImplemented
- result = self.__class__()
- result._data = self._data.copy()
- result._data.update(other._data)
- return result
+ return self.union(other)
def union(self, other):
"""Return the union of two sets as a new set.
(I.e. all elements that are in either set.)
"""
- return self | other
+ result = self.__class__(self)
+ result._update(other)
+ return result
def __and__(self, other):
"""Return the intersection of two sets as a new set.
@@ -215,19 +214,21 @@ class BaseSet(object):
"""
if not isinstance(other, BaseSet):
return NotImplemented
- if len(self) <= len(other):
- little, big = self, other
- else:
- little, big = other, self
- common = ifilter(big._data.has_key, little)
- return self.__class__(common)
+ return self.intersection(other)
def intersection(self, other):
"""Return the intersection of two sets as a new set.
(I.e. all elements that are in both sets.)
"""
- return self & other
+ if not isinstance(other, BaseSet):
+ other = Set(other)
+ if len(self) <= len(other):
+ little, big = self, other
+ else:
+ little, big = other, self
+ common = ifilter(big._data.has_key, little)
+ return self.__class__(common)
def __xor__(self, other):
"""Return the symmetric difference of two sets as a new set.
@@ -236,24 +237,27 @@ class BaseSet(object):
"""
if not isinstance(other, BaseSet):
return NotImplemented
+ return self.symmetric_difference(other)
+
+ def symmetric_difference(self, other):
+ """Return the symmetric difference of two sets as a new set.
+
+ (I.e. all elements that are in exactly one of the sets.)
+ """
result = self.__class__()
data = result._data
value = True
selfdata = self._data
- otherdata = other._data
+ try:
+ otherdata = other._data
+ except AttributeError:
+ otherdata = Set(other)._data
for elt in ifilterfalse(otherdata.has_key, selfdata):
data[elt] = value
for elt in ifilterfalse(selfdata.has_key, otherdata):
data[elt] = value
return result
- def symmetric_difference(self, other):
- """Return the symmetric difference of two sets as a new set.
-
- (I.e. all elements that are in exactly one of the sets.)
- """
- return self ^ other
-
def __sub__(self, other):
"""Return the difference of two sets as a new Set.
@@ -261,19 +265,23 @@ class BaseSet(object):
"""
if not isinstance(other, BaseSet):
return NotImplemented
- result = self.__class__()
- data = result._data
- value = True
- for elt in ifilterfalse(other._data.has_key, self):
- data[elt] = value
- return result
+ return self.difference(other)
def difference(self, other):
"""Return the difference of two sets as a new Set.
(I.e. all elements that are in this set and not in the other.)
"""
- return self - other
+ result = self.__class__()
+ data = result._data
+ try:
+ otherdata = other._data
+ except AttributeError:
+ otherdata = Set(other)._data
+ value = True
+ for elt in ifilterfalse(otherdata.has_key, self):
+ data[elt] = value
+ return result
# Membership test
@@ -441,7 +449,7 @@ class Set(BaseSet):
def union_update(self, other):
"""Update a set with the union of itself and another."""
- self |= other
+ self._update(other)
def __iand__(self, other):
"""Update a set with the intersection of itself and another."""
@@ -451,40 +459,51 @@ class Set(BaseSet):
def intersection_update(self, other):
"""Update a set with the intersection of itself and another."""
- self &= other
+ if isinstance(other, BaseSet):
+ self &= other
+ else:
+ self._data = (self.intersection(other))._data
def __ixor__(self, other):
"""Update a set with the symmetric difference of itself and another."""
self._binary_sanity_check(other)
+ self.symmetric_difference_update(other)
+ return self
+
+ def symmetric_difference_update(self, other):
+ """Update a set with the symmetric difference of itself and another."""
data = self._data
value = True
+ if not isinstance(other, BaseSet):
+ other = Set(other)
for elt in other:
if elt in data:
del data[elt]
else:
data[elt] = value
- return self
-
- def symmetric_difference_update(self, other):
- """Update a set with the symmetric difference of itself and another."""
- self ^= other
def __isub__(self, other):
"""Remove all elements of another set from this set."""
self._binary_sanity_check(other)
- data = self._data
- for elt in ifilter(data.has_key, other):
- del data[elt]
+ self.difference_update(other)
return self
def difference_update(self, other):
"""Remove all elements of another set from this set."""
- self -= other
+ data = self._data
+ if not isinstance(other, BaseSet):
+ other = Set(other)
+ for elt in ifilter(data.has_key, other):
+ del data[elt]
# Python dict-like mass mutations: update, clear
def update(self, iterable):
"""Add all values from an iterable (such as a list or file)."""
+ import warnings
+ warnings.warn("The update() method deprecated; "
+ "Use union_update() instead",
+ DeprecationWarning)
self._update(iterable)
def clear(self):