diff options
author | Guido van Rossum <guido@python.org> | 2002-08-19 16:19:15 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2002-08-19 16:19:15 (GMT) |
commit | d6cf3af8f79f19bf93c1db4d55772f68d59ab454 (patch) | |
tree | 8948f073b544f90d281b96ea70edccd57a57116c | |
parent | c588e9041aea212fe2b5fad9254824d12f804c3e (diff) | |
download | cpython-d6cf3af8f79f19bf93c1db4d55772f68d59ab454.zip cpython-d6cf3af8f79f19bf93c1db4d55772f68d59ab454.tar.gz cpython-d6cf3af8f79f19bf93c1db4d55772f68d59ab454.tar.bz2 |
Set classes and their unit tests, from sandbox.
-rw-r--r-- | Lib/sets.py | 529 | ||||
-rw-r--r-- | Lib/test/test_sets.py | 568 |
2 files changed, 1097 insertions, 0 deletions
diff --git a/Lib/sets.py b/Lib/sets.py new file mode 100644 index 0000000..1072dd2 --- /dev/null +++ b/Lib/sets.py @@ -0,0 +1,529 @@ +"""Classes to represent arbitrary sets (including sets of sets). + +This module implements sets using dictionaries whose values are +ignored. The usual operations (union, intersection, deletion, etc.) +are provided as both methods and operators. + +The following classes are provided: + +BaseSet -- All the operations common to both mutable and immutable + sets. This is an abstract class, not meant to be directly + instantiated. + +Set -- Mutable sets, subclass of BaseSet; not hashable. + +ImmutableSet -- Immutable sets, subclass of BaseSet; hashable. + An iterable argument is mandatory to create an ImmutableSet. + +_TemporarilyImmutableSet -- Not a subclass of BaseSet: just a wrapper + around a Set, hashable, giving the same hash value as the + immutable set equivalent would have. Do not use this class + directly. + +Only hashable objects can be added to a Set. In particular, you cannot +really add a Set as an element to another Set; if you try, what is +actuallly added is an ImmutableSet built from it (it compares equal to +the one you tried adding). + +When you ask if `x in y' where x is a Set and y is a Set or +ImmutableSet, x is wrapped into a _TemporarilyImmutableSet z, and +what's tested is actually `z in y'. + +""" + +# Code history: +# +# - Greg V. Wilson wrote the first version, using a different approach +# to the mutable/immutable problem, and inheriting from dict. +# +# - Alex Martelli modified Greg's version to implement the current +# Set/ImmutableSet approach, and make the data an attribute. +# +# - Guido van Rossum rewrote much of the code, made some API changes, +# and cleaned up the docstrings. + + +__all__ = ['BaseSet', 'Set', 'ImmutableSet'] + + +class BaseSet(object): + """Common base class for mutable and immutable sets.""" + + __slots__ = ['_data'] + + # Constructor + + def __init__(self, seq=None): + """Construct a set, optionally initializing it from a sequence.""" + self._data = {} + if seq is not None: + # I don't know a faster way to do this in pure Python. + # Custom code written in C only did it 65% faster, + # preallocating the dict to len(seq); without + # preallocation it was only 25% faster. So the speed of + # this Python code is respectable. Just copying True into + # a local variable is responsible for a 7-8% speedup. + data = self._data + value = True + for key in seq: + data[key] = value + + # Standard protocols: __len__, __repr__, __str__, __iter__ + + def __len__(self): + """Return the number of elements of a set.""" + return len(self._data) + + def __repr__(self): + """Return string representation of a set. + + This looks like 'Set([<list of elements>])'. + """ + return self._repr() + + # __str__ is the same as __repr__ + __str__ = __repr__ + + def _repr(self, sorted=False): + elements = self._data.keys() + if sorted: + elements.sort() + return '%s(%r)' % (self.__class__.__name__, elements) + + def __iter__(self): + """Return an iterator over the elements or a set. + + This is the keys iterator for the underlying dict. + """ + return self._data.iterkeys() + + # Comparisons. Ordering is determined by the ordering of the + # underlying dicts (which is consistent though unpredictable). + + def __lt__(self, other): + self._binary_sanity_check(other) + return self._data < other._data + + def __le__(self, other): + self._binary_sanity_check(other) + return self._data <= other._data + + def __eq__(self, other): + self._binary_sanity_check(other) + return self._data == other._data + + def __ne__(self, other): + self._binary_sanity_check(other) + return self._data != other._data + + def __gt__(self, other): + self._binary_sanity_check(other) + return self._data > other._data + + def __ge__(self, other): + self._binary_sanity_check(other) + return self._data >= other._data + + # Copying operations + + def copy(self): + """Return a shallow copy of a set.""" + return self.__class__(self) + + __copy__ = copy # For the copy module + + def __deepcopy__(self, memo): + """Return a deep copy of a set; used by copy module.""" + # This pre-creates the result and inserts it in the memo + # early, in case the deep copy recurses into another reference + # to this same set. A set can't be an element of itself, but + # it can certainly contain an object that has a reference to + # itself. + from copy import deepcopy + result = self.__class__([]) + memo[id(self)] = result + data = result._data + value = True + for elt in self: + data[deepcopy(elt, memo)] = value + return result + + # Standard set operations: union, intersection, both differences + + def union(self, other): + """Return the union of two sets as a new set. + + (I.e. all elements that are in either set.) + """ + self._binary_sanity_check(other) + result = self.__class__(self._data) + result._data.update(other._data) + return result + + __or__ = union + + def intersection(self, other): + """Return the intersection of two sets as a new set. + + (I.e. all elements that are in both sets.) + """ + self._binary_sanity_check(other) + if len(self) <= len(other): + little, big = self, other + else: + little, big = other, self + result = self.__class__([]) + data = result._data + value = True + for elt in little: + if elt in big: + data[elt] = value + return result + + __and__ = intersection + + def symmetric_difference(self, other): + """Return the symmetric difference of two sets as a new set. + + (I.e. all elements that are in exactly one of the sets.) + """ + self._binary_sanity_check(other) + result = self.__class__([]) + data = result._data + value = True + for elt in self: + if elt not in other: + data[elt] = value + for elt in other: + if elt not in self: + data[elt] = value + return result + + __xor__ = symmetric_difference + + def difference(self, other): + """Return the difference of two sets as a new Set. + + (I.e. all elements that are in this set and not in the other.) + """ + self._binary_sanity_check(other) + result = self.__class__([]) + data = result._data + value = True + for elt in self: + if elt not in other: + data[elt] = value + return result + + __sub__ = difference + + # Membership test + + def __contains__(self, element): + """Report whether an element is a member of a set. + + (Called in response to the expression `element in self'.) + """ + try: + transform = element._as_temporarily_immutable + except AttributeError: + pass + else: + element = transform() + return element in self._data + + # Subset and superset test + + def issubset(self, other): + """Report whether another set contains this set.""" + self._binary_sanity_check(other) + for elt in self: + if elt not in other: + return False + return True + + def issuperset(self, other): + """Report whether this set contains another set.""" + self._binary_sanity_check(other) + for elt in other: + if elt not in self: + return False + return True + + # Assorted helpers + + def _binary_sanity_check(self, other): + # Check that the other argument to a binary operation is also + # a set, raising a TypeError otherwise. + if not isinstance(other, BaseSet): + raise TypeError, "Binary operation only permitted between sets" + + def _compute_hash(self): + # Calculate hash code for a set by xor'ing the hash codes of + # the elements. This algorithm ensures that the hash code + # does not depend on the order in which elements are added to + # the code. This is not called __hash__ because a BaseSet + # should not be hashable; only an ImmutableSet is hashable. + result = 0 + for elt in self: + result ^= hash(elt) + return result + + +class ImmutableSet(BaseSet): + """Immutable set class.""" + + __slots__ = ['_hash'] + + # BaseSet + hashing + + def __init__(self, seq): + """Construct an immutable set from a sequence.""" + # Override the constructor to make 'seq' a required argument + BaseSet.__init__(self, seq) + self._hashcode = None + + def __hash__(self): + if self._hashcode is None: + self._hashcode = self._compute_hash() + return self._hashcode + + +class Set(BaseSet): + """ Mutable set class.""" + + __slots__ = [] + + # BaseSet + operations requiring mutability; no hashing + + # In-place union, intersection, differences + + def union_update(self, other): + """Update a set with the union of itself and another.""" + self._binary_sanity_check(other) + self._data.update(other._data) + return self + + __ior__ = union_update + + def intersection_update(self, other): + """Update a set with the intersection of itself and another.""" + self._binary_sanity_check(other) + for elt in self._data.keys(): + if elt not in other: + del self._data[elt] + return self + + __iand__ = intersection_update + + def symmetric_difference_update(self, other): + """Update a set with the symmetric difference of itself and another.""" + self._binary_sanity_check(other) + data = self._data + value = True + for elt in other: + if elt in data: + del data[elt] + else: + data[elt] = value + return self + + __ixor__ = symmetric_difference_update + + def difference_update(self, other): + """Remove all elements of another set from this set.""" + self._binary_sanity_check(other) + data = self._data + for elt in other: + if elt in data: + del data[elt] + return self + + __isub__ = difference_update + + # Python dict-like mass mutations: update, clear + + def update(self, iterable): + """Add all values from an iterable (such as a list or file).""" + data = self._data + value = True + for elt in iterable: + try: + transform = elt._as_immutable + except AttributeError: + pass + else: + elt = transform() + data[elt] = value + + def clear(self): + """Remove all elements from this set.""" + self._data.clear() + + # Single-element mutations: add, remove, discard + + def add(self, element): + """Add an element to a set. + + This has no effect if the element is already present. + """ + try: + transform = element._as_immutable + except AttributeError: + pass + else: + element = transform() + self._data[element] = True + + def remove(self, element): + """Remove an element from a set; it must be a member. + + If the element is not a member, raise a KeyError. + """ + try: + transform = element._as_temporarily_immutable + except AttributeError: + pass + else: + element = transform() + del self._data[element] + + def discard(self, element): + """Remove an element from a set if it is a member. + + If the element is not a member, do nothing. + """ + try: + del self._data[element] + except KeyError: + pass + + def popitem(self): + """Remove and return a randomly-chosen set element.""" + return self._data.popitem()[0] + + def _as_immutable(self): + # Return a copy of self as an immutable set + return ImmutableSet(self) + + def _as_temporarily_immutable(self): + # Return self wrapped in a temporarily immutable set + return _TemporarilyImmutableSet(self) + + +class _TemporarilyImmutableSet(object): + # Wrap a mutable set as if it was temporarily immutable. + # This only supplies hashing and equality comparisons. + + _hashcode = None + + def __init__(self, set): + self._set = set + + def __hash__(self): + if self._hashcode is None: + self._hashcode = self._set._compute_hash() + return self._hashcode + + def __eq__(self, other): + return self._set == other + + def __ne__(self, other): + return self._set != other + + +# Rudimentary self-tests + +def _test(): + + # Empty set + red = Set() + assert `red` == "Set([])", "Empty set: %s" % `red` + + # Unit set + green = Set((0,)) + assert `green` == "Set([0])", "Unit set: %s" % `green` + + # 3-element set + blue = Set([0, 1, 2]) + assert blue._repr(True) == "Set([0, 1, 2])", "3-element set: %s" % `blue` + + # 2-element set with other values + black = Set([0, 5]) + assert black._repr(True) == "Set([0, 5])", "2-element set: %s" % `black` + + # All elements from all sets + white = Set([0, 1, 2, 5]) + assert white._repr(True) == "Set([0, 1, 2, 5])", "4-element set: %s" % `white` + + # Add element to empty set + red.add(9) + assert `red` == "Set([9])", "Add to empty set: %s" % `red` + + # Remove element from unit set + red.remove(9) + assert `red` == "Set([])", "Remove from unit set: %s" % `red` + + # Remove element from empty set + try: + red.remove(0) + assert 0, "Remove element from empty set: %s" % `red` + except LookupError: + pass + + # Length + assert len(red) == 0, "Length of empty set" + assert len(green) == 1, "Length of unit set" + assert len(blue) == 3, "Length of 3-element set" + + # Compare + assert green == Set([0]), "Equality failed" + assert green != Set([1]), "Inequality failed" + + # Union + assert blue | red == blue, "Union non-empty with empty" + assert red | blue == blue, "Union empty with non-empty" + assert green | blue == blue, "Union non-empty with non-empty" + assert blue | black == white, "Enclosing union" + + # Intersection + assert blue & red == red, "Intersect non-empty with empty" + assert red & blue == red, "Intersect empty with non-empty" + assert green & blue == green, "Intersect non-empty with non-empty" + assert blue & black == green, "Enclosing intersection" + + # Symmetric difference + assert red ^ green == green, "Empty symdiff non-empty" + assert green ^ blue == Set([1, 2]), "Non-empty symdiff" + assert white ^ white == red, "Self symdiff" + + # Difference + assert red - green == red, "Empty - non-empty" + assert blue - red == blue, "Non-empty - empty" + assert white - black == Set([1, 2]), "Non-empty - non-empty" + + # In-place union + orange = Set([]) + orange |= Set([1]) + assert orange == Set([1]), "In-place union" + + # In-place intersection + orange = Set([1, 2]) + orange &= Set([2]) + assert orange == Set([2]), "In-place intersection" + + # In-place difference + orange = Set([1, 2, 3]) + orange -= Set([2, 4]) + assert orange == Set([1, 3]), "In-place difference" + + # In-place symmetric difference + orange = Set([1, 2, 3]) + orange ^= Set([3, 4]) + assert orange == Set([1, 2, 4]), "In-place symmetric difference" + + print "All tests passed" + + +if __name__ == "__main__": + _test() diff --git a/Lib/test/test_sets.py b/Lib/test/test_sets.py new file mode 100644 index 0000000..6c72b0e --- /dev/null +++ b/Lib/test/test_sets.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python + +import unittest, operator, copy +from sets import Set, ImmutableSet +from test import test_support + +empty_set = Set() + +#============================================================================== + +class TestBasicOps(unittest.TestCase): + + def test_repr(self): + if self.repr is not None: + assert `self.set` == self.repr, "Wrong representation for " + self.case + + def test_length(self): + assert len(self.set) == self.length, "Wrong length for " + self.case + + def test_self_equality(self): + assert self.set == self.set, "Self-equality failed for " + self.case + + def test_equivalent_equality(self): + assert self.set == self.dup, "Equivalent equality failed for " + self.case + + def test_copy(self): + assert self.set.copy() == self.dup, "Copy and comparison failed for " + self.case + + def test_self_union(self): + result = self.set | self.set + assert result == self.dup, "Self-union failed for " + self.case + + def test_empty_union(self): + result = self.set | empty_set + assert result == self.dup, "Union with empty failed for " + self.case + + def test_union_empty(self): + result = empty_set | self.set + assert result == self.dup, "Union with empty failed for " + self.case + + def test_self_intersection(self): + result = self.set & self.set + assert result == self.dup, "Self-intersection failed for " + self.case + + def test_empty_intersection(self): + result = self.set & empty_set + assert result == empty_set, "Intersection with empty failed for " + self.case + + def test_intersection_empty(self): + result = empty_set & self.set + assert result == empty_set, "Intersection with empty failed for " + self.case + + def test_self_symmetric_difference(self): + result = self.set ^ self.set + assert result == empty_set, "Self-symdiff failed for " + self.case + + def checkempty_symmetric_difference(self): + result = self.set ^ empty_set + assert result == self.set, "Symdiff with empty failed for " + self.case + + def test_self_difference(self): + result = self.set - self.set + assert result == empty_set, "Self-difference failed for " + self.case + + def test_empty_difference(self): + result = self.set - empty_set + assert result == self.dup, "Difference with empty failed for " + self.case + + def test_empty_difference_rev(self): + result = empty_set - self.set + assert result == empty_set, "Difference from empty failed for " + self.case + + def test_iteration(self): + for v in self.set: + assert v in self.values, "Missing item in iteration for " + self.case + +#------------------------------------------------------------------------------ + +class TestBasicOpsEmpty(TestBasicOps): + def setUp(self): + self.case = "empty set" + self.values = [] + self.set = Set(self.values) + self.dup = Set(self.values) + self.length = 0 + self.repr = "Set([])" + +#------------------------------------------------------------------------------ + +class TestBasicOpsSingleton(TestBasicOps): + def setUp(self): + self.case = "unit set (number)" + self.values = [3] + self.set = Set(self.values) + self.dup = Set(self.values) + self.length = 1 + self.repr = "Set([3])" + + def test_in(self): + assert 3 in self.set, "Valueship for unit set" + + def test_not_in(self): + assert 2 not in self.set, "Non-valueship for unit set" + +#------------------------------------------------------------------------------ + +class TestBasicOpsTuple(TestBasicOps): + def setUp(self): + self.case = "unit set (tuple)" + self.values = [(0, "zero")] + self.set = Set(self.values) + self.dup = Set(self.values) + self.length = 1 + self.repr = "Set([(0, 'zero')])" + + def test_in(self): + assert (0, "zero") in self.set, "Valueship for tuple set" + + def test_not_in(self): + assert 9 not in self.set, "Non-valueship for tuple set" + +#------------------------------------------------------------------------------ + +class TestBasicOpsTriple(TestBasicOps): + def setUp(self): + self.case = "triple set" + self.values = [0, "zero", operator.add] + self.set = Set(self.values) + self.dup = Set(self.values) + self.length = 3 + self.repr = None + +#============================================================================== + +class TestBinaryOps(unittest.TestCase): + def setUp(self): + self.set = Set((2, 4, 6)) + + def test_union_subset(self): + result = self.set | Set([2]) + assert result == Set((2, 4, 6)), "Subset union" + + def test_union_superset(self): + result = self.set | Set([2, 4, 6, 8]) + assert result == Set([2, 4, 6, 8]), "Superset union" + + def test_union_overlap(self): + result = self.set | Set([3, 4, 5]) + assert result == Set([2, 3, 4, 5, 6]), "Overlapping union" + + def test_union_non_overlap(self): + result = self.set | Set([8]) + assert result == Set([2, 4, 6, 8]), "Non-overlapping union" + + def test_intersection_subset(self): + result = self.set & Set((2, 4)) + assert result == Set((2, 4)), "Subset intersection" + + def test_intersection_superset(self): + result = self.set & Set([2, 4, 6, 8]) + assert result == Set([2, 4, 6]), "Superset intersection" + + def test_intersection_overlap(self): + result = self.set & Set([3, 4, 5]) + assert result == Set([4]), "Overlapping intersection" + + def test_intersection_non_overlap(self): + result = self.set & Set([8]) + assert result == empty_set, "Non-overlapping intersection" + + def test_sym_difference_subset(self): + result = self.set ^ Set((2, 4)) + assert result == Set([6]), "Subset symmetric difference" + + def test_sym_difference_superset(self): + result = self.set ^ Set((2, 4, 6, 8)) + assert result == Set([8]), "Superset symmetric difference" + + def test_sym_difference_overlap(self): + result = self.set ^ Set((3, 4, 5)) + assert result == Set([2, 3, 5, 6]), "Overlapping symmetric difference" + + def test_sym_difference_non_overlap(self): + result = self.set ^ Set([8]) + assert result == Set([2, 4, 6, 8]), "Non-overlapping symmetric difference" + +#============================================================================== + +class TestUpdateOps(unittest.TestCase): + def setUp(self): + self.set = Set((2, 4, 6)) + + def test_union_subset(self): + self.set |= Set([2]) + assert self.set == Set((2, 4, 6)), "Subset union" + + def test_union_superset(self): + self.set |= Set([2, 4, 6, 8]) + assert self.set == Set([2, 4, 6, 8]), "Superset union" + + def test_union_overlap(self): + self.set |= Set([3, 4, 5]) + assert self.set == Set([2, 3, 4, 5, 6]), "Overlapping union" + + def test_union_non_overlap(self): + self.set |= Set([8]) + assert self.set == Set([2, 4, 6, 8]), "Non-overlapping union" + + def test_intersection_subset(self): + self.set &= Set((2, 4)) + assert self.set == Set((2, 4)), "Subset intersection" + + def test_intersection_superset(self): + self.set &= Set([2, 4, 6, 8]) + assert self.set == Set([2, 4, 6]), "Superset intersection" + + def test_intersection_overlap(self): + self.set &= Set([3, 4, 5]) + assert self.set == Set([4]), "Overlapping intersection" + + def test_intersection_non_overlap(self): + self.set &= Set([8]) + assert self.set == empty_set, "Non-overlapping intersection" + + def test_sym_difference_subset(self): + self.set ^= Set((2, 4)) + assert self.set == Set([6]), "Subset symmetric difference" + + def test_sym_difference_superset(self): + self.set ^= Set((2, 4, 6, 8)) + assert self.set == Set([8]), "Superset symmetric difference" + + def test_sym_difference_overlap(self): + self.set ^= Set((3, 4, 5)) + assert self.set == Set([2, 3, 5, 6]), "Overlapping symmetric difference" + + def test_sym_difference_non_overlap(self): + self.set ^= Set([8]) + assert self.set == Set([2, 4, 6, 8]), "Non-overlapping symmetric difference" + +#============================================================================== + +class TestMutate(unittest.TestCase): + def setUp(self): + self.values = ["a", "b", "c"] + self.set = Set(self.values) + + def test_add_present(self): + self.set.add("c") + assert self.set == Set(("a", "b", "c")), "Adding present element" + + def test_add_absent(self): + self.set.add("d") + assert self.set == Set(("a", "b", "c", "d")), "Adding missing element" + + def test_add_until_full(self): + tmp = Set() + expected_len = 0 + for v in self.values: + tmp.add(v) + expected_len += 1 + assert len(tmp) == expected_len, "Adding values one by one to temporary" + assert tmp == self.set, "Adding values one by one" + + def test_remove_present(self): + self.set.remove("b") + assert self.set == Set(("a", "c")), "Removing present element" + + def test_remove_absent(self): + try: + self.set.remove("d") + assert 0, "Removing missing element" + except LookupError: + pass + + def test_remove_until_empty(self): + expected_len = len(self.set) + for v in self.values: + self.set.remove(v) + expected_len -= 1 + assert len(self.set) == expected_len, "Removing values one by one" + + def test_discard_present(self): + self.set.discard("c") + assert self.set == Set(("a", "b")), "Discarding present element" + + def test_discard_absent(self): + self.set.discard("d") + assert self.set == Set(("a", "b", "c")), "Discarding missing element" + + def test_clear(self): + self.set.clear() + assert len(self.set) == 0, "Clearing set" + + def test_popitem(self): + popped = {} + while self.set: + popped[self.set.popitem()] = None + assert len(popped) == len(self.values), "Popping items" + for v in self.values: + assert v in popped, "Popping items" + + def test_update_empty_tuple(self): + self.set.update(()) + assert self.set == Set(self.values), "Updating with empty tuple" + + def test_update_unit_tuple_overlap(self): + self.set.update(("a",)) + assert self.set == Set(self.values), "Updating with overlapping unit tuple" + + def test_update_unit_tuple_non_overlap(self): + self.set.update(("a", "z")) + assert self.set == Set(self.values + ["z"]), "Updating with non-overlapping unit tuple" + +#============================================================================== + +class TestSubsets(unittest.TestCase): + + def test_issubset(self): + result = self.left.issubset(self.right) + if "<" in self.cases: + assert result, "subset: " + self.name + else: + assert not result, "non-subset: " + self.name + +#------------------------------------------------------------------------------ + +class TestSubsetEqualEmpty(TestSubsets): + def setUp(self): + self.left = Set() + self.right = Set() + self.name = "both empty" + self.cases = "<>" + +#------------------------------------------------------------------------------ + +class TestSubsetEqualNonEmpty(TestSubsets): + def setUp(self): + self.left = Set([1, 2]) + self.right = Set([1, 2]) + self.name = "equal pair" + self.cases = "<>" + +#------------------------------------------------------------------------------ + +class TestSubsetEmptyNonEmpty(TestSubsets): + def setUp(self): + self.left = Set() + self.right = Set([1, 2]) + self.name = "one empty, one non-empty" + self.cases = "<" + +#------------------------------------------------------------------------------ + +class TestSubsetPartial(TestSubsets): + def setUp(self): + self.left = Set([1]) + self.right = Set([1, 2]) + self.name = "one a non-empty subset of other" + self.cases = "<" + +#------------------------------------------------------------------------------ + +class TestSubsetNonOverlap(TestSubsets): + def setUp(self): + self.left = Set([1]) + self.right = Set([2]) + self.name = "neither empty, neither contains" + self.cases = "" + +#============================================================================== + +class TestOnlySetsInBinaryOps(unittest.TestCase): + + def test_cmp(self): + try: + self.other < self.set + assert 0, "Comparison with non-set on left" + except TypeError: + pass + try: + self.set >= self.other + assert 0, "Comparison with non-set on right" + except TypeError: + pass + + def test_union_update(self): + try: + self.set |= self.other + assert 0, "Union update with non-set" + except TypeError: + pass + + def test_union(self): + try: + self.other | self.set + assert 0, "Union with non-set on left" + except TypeError: + pass + try: + self.set | self.other + assert 0, "Union with non-set on right" + except TypeError: + pass + + def test_intersection_update(self): + try: + self.set &= self.other + assert 0, "Intersection update with non-set" + except TypeError: + pass + + def test_intersection(self): + try: + self.other & self.set + assert 0, "Intersection with non-set on left" + except TypeError: + pass + try: + self.set & self.other + assert 0, "Intersection with non-set on right" + except TypeError: + pass + + def test_sym_difference_update(self): + try: + self.set ^= self.other + assert 0, "Symmetric difference update with non-set" + except TypeError: + pass + + def test_sym_difference(self): + try: + self.other ^ self.set + assert 0, "Symmetric difference with non-set on left" + except TypeError: + pass + try: + self.set ^ self.other + assert 0, "Symmetric difference with non-set on right" + except TypeError: + pass + + def test_difference_update(self): + try: + self.set -= self.other + assert 0, "Symmetric difference update with non-set" + except TypeError: + pass + + def test_difference(self): + try: + self.other - self.set + assert 0, "Symmetric difference with non-set on left" + except TypeError: + pass + try: + self.set - self.other + assert 0, "Symmetric difference with non-set on right" + except TypeError: + pass + +#------------------------------------------------------------------------------ + +class TestOnlySetsNumeric(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = 19 + +#------------------------------------------------------------------------------ + +class TestOnlySetsDict(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = {1:2, 3:4} + +#------------------------------------------------------------------------------ + +class TestOnlySetsOperator(TestOnlySetsInBinaryOps): + def setUp(self): + self.set = Set((1, 2, 3)) + self.other = operator.add + +#============================================================================== + +class TestCopying(unittest.TestCase): + + def test_copy(self): + dup = self.set.copy() + dup_list = list(dup); dup_list.sort() + set_list = list(self.set); set_list.sort() + assert len(dup_list) == len(set_list), "Unequal lengths after copy" + for i in range(len(dup_list)): + assert dup_list[i] is set_list[i], "Non-identical items after copy" + + def test_deep_copy(self): + dup = copy.deepcopy(self.set) + ##print type(dup), `dup` + dup_list = list(dup); dup_list.sort() + set_list = list(self.set); set_list.sort() + assert len(dup_list) == len(set_list), "Unequal lengths after deep copy" + for i in range(len(dup_list)): + assert dup_list[i] == set_list[i], "Unequal items after deep copy" + +#------------------------------------------------------------------------------ + +class TestCopyingEmpty(TestCopying): + def setUp(self): + self.set = Set() + +#------------------------------------------------------------------------------ + +class TestCopyingSingleton(TestCopying): + def setUp(self): + self.set = Set(["hello"]) + +#------------------------------------------------------------------------------ + +class TestCopyingTriple(TestCopying): + def setUp(self): + self.set = Set(["zero", 0, None]) + +#------------------------------------------------------------------------------ + +class TestCopyingTuple(TestCopying): + def setUp(self): + self.set = Set([(1, 2)]) + +#------------------------------------------------------------------------------ + +class TestCopyingNested(TestCopying): + def setUp(self): + self.set = Set([((1, 2), (3, 4))]) + +#============================================================================== + +def makeAllTests(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestBasicOpsEmpty)) + suite.addTest(unittest.makeSuite(TestBasicOpsSingleton)) + suite.addTest(unittest.makeSuite(TestBasicOpsTuple)) + suite.addTest(unittest.makeSuite(TestBasicOpsTriple)) + suite.addTest(unittest.makeSuite(TestBinaryOps)) + suite.addTest(unittest.makeSuite(TestUpdateOps)) + suite.addTest(unittest.makeSuite(TestMutate)) + suite.addTest(unittest.makeSuite(TestSubsetEqualEmpty)) + suite.addTest(unittest.makeSuite(TestSubsetEqualNonEmpty)) + suite.addTest(unittest.makeSuite(TestSubsetEmptyNonEmpty)) + suite.addTest(unittest.makeSuite(TestSubsetPartial)) + suite.addTest(unittest.makeSuite(TestSubsetNonOverlap)) + suite.addTest(unittest.makeSuite(TestOnlySetsNumeric)) + suite.addTest(unittest.makeSuite(TestOnlySetsDict)) + suite.addTest(unittest.makeSuite(TestOnlySetsOperator)) + suite.addTest(unittest.makeSuite(TestCopyingEmpty)) + suite.addTest(unittest.makeSuite(TestCopyingSingleton)) + suite.addTest(unittest.makeSuite(TestCopyingTriple)) + suite.addTest(unittest.makeSuite(TestCopyingTuple)) + suite.addTest(unittest.makeSuite(TestCopyingNested)) + return suite + +#------------------------------------------------------------------------------ + +def test_main(): + suite = makeAllTests() + test_support.run_suite(suite) + +if __name__ == "__main__": + test_main() |