From 64c06e327d48150fc548cf18a4a7ae0b890e69fa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 22 Nov 2007 00:55:51 +0000 Subject: Backport of _abccoll.py by Benjamin Arangueren, issue 1383. With some changes of my own thrown in (e.g. backport of r58107). --- Lib/_abcoll.py | 544 +++++++++++++++++++++++++++++++++++++++++++ Lib/abc.py | 2 +- Lib/collections.py | 6 + Lib/test/mapping_tests.py | 4 + Lib/test/regrtest.py | 23 +- Lib/test/test_abc.py | 14 ++ Lib/test/test_collections.py | 186 ++++++++++++++- Lib/test/test_dict.py | 4 + Objects/dictobject.c | 9 +- Objects/listobject.c | 10 +- Objects/object.c | 2 +- Objects/setobject.c | 9 +- Objects/typeobject.c | 60 +++-- 13 files changed, 822 insertions(+), 51 deletions(-) create mode 100644 Lib/_abcoll.py diff --git a/Lib/_abcoll.py b/Lib/_abcoll.py new file mode 100644 index 0000000..ac967b2 --- /dev/null +++ b/Lib/_abcoll.py @@ -0,0 +1,544 @@ +# Copyright 2007 Google, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Abstract Base Classes (ABCs) for collections, according to PEP 3119. + +DON'T USE THIS MODULE DIRECTLY! The classes here should be imported +via collections; they are defined here only to alleviate certain +bootstrapping issues. Unit tests are in test_collections. +""" + +from abc import ABCMeta, abstractmethod + +__all__ = ["Hashable", "Iterable", "Iterator", + "Sized", "Container", "Callable", + "Set", "MutableSet", + "Mapping", "MutableMapping", + "MappingView", "KeysView", "ItemsView", "ValuesView", + "Sequence", "MutableSequence", + ] + +### ONE-TRICK PONIES ### + +class Hashable: + __metaclass__ = ABCMeta + + @abstractmethod + def __hash__(self): + return 0 + + @classmethod + def __subclasshook__(cls, C): + if cls is Hashable: + for B in C.__mro__: + if "__hash__" in B.__dict__: + if B.__dict__["__hash__"]: + return True + break + return NotImplemented + + +class Iterable: + __metaclass__ = ABCMeta + + @abstractmethod + def __iter__(self): + while False: + yield None + + @classmethod + def __subclasshook__(cls, C): + if cls is Iterable: + if any("__iter__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + +Iterable.register(str) + + +class Iterator: + __metaclass__ = ABCMeta + + @abstractmethod + def __next__(self): + raise StopIteration + + def __iter__(self): + return self + + @classmethod + def __subclasshook__(cls, C): + if cls is Iterator: + if any("next" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + + +class Sized: + __metaclass__ = ABCMeta + + @abstractmethod + def __len__(self): + return 0 + + @classmethod + def __subclasshook__(cls, C): + if cls is Sized: + if any("__len__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + + +class Container: + __metaclass__ = ABCMeta + + @abstractmethod + def __contains__(self, x): + return False + + @classmethod + def __subclasshook__(cls, C): + if cls is Container: + if any("__contains__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + + +class Callable: + __metaclass__ = ABCMeta + + @abstractmethod + def __contains__(self, x): + return False + + @classmethod + def __subclasshook__(cls, C): + if cls is Callable: + if any("__call__" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + + +### SETS ### + + +class Set: + __metaclass__ = ABCMeta + + """A set is a finite, iterable container. + + This class provides concrete generic implementations of all + methods except for __contains__, __iter__ and __len__. + + To override the comparisons (presumably for speed, as the + semantics are fixed), all you have to do is redefine __le__ and + then the other operations will automatically follow suit. + """ + + @abstractmethod + def __contains__(self, value): + return False + + @abstractmethod + def __iter__(self): + while False: + yield None + + @abstractmethod + def __len__(self): + return 0 + + def __le__(self, other): + if not isinstance(other, Set): + return NotImplemented + if len(self) > len(other): + return False + for elem in self: + if elem not in other: + return False + return True + + def __lt__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) < len(other) and self.__le__(other) + + def __eq__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) == len(other) and self.__le__(other) + + @classmethod + def _from_iterable(cls, it): + return frozenset(it) + + def __and__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + return self._from_iterable(value for value in other if value in self) + + def __or__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + return self._from_iterable(itertools.chain(self, other)) + + def __sub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in self + if value not in other) + + def __xor__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return (self - other) | (other - self) + + def _hash(self): + """Compute the hash value of a set. + + Note that we don't define __hash__: not all sets are hashable. + But if you define a hashable set type, its __hash__ should + call this function. + + This must be compatible __eq__. + + All sets ought to compare equal if they contain the same + elements, regardless of how they are implemented, and + regardless of the order of the elements; so there's not much + freedom for __eq__ or __hash__. We match the algorithm used + by the built-in frozenset type. + """ + MAX = sys.maxint + MASK = 2 * MAX + 1 + n = len(self) + h = 1927868237 * (n + 1) + h &= MASK + for x in self: + hx = hash(x) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + h = h * 69069 + 907133923 + h &= MASK + if h > MAX: + h -= MASK + 1 + if h == -1: + h = 590923713 + return h + +Set.register(frozenset) + + +class MutableSet(Set): + + @abstractmethod + def add(self, value): + """Return True if it was added, False if already there.""" + raise NotImplementedError + + @abstractmethod + def discard(self, value): + """Return True if it was deleted, False if not there.""" + raise NotImplementedError + + def pop(self): + """Return the popped value. Raise KeyError if empty.""" + it = iter(self) + try: + value = it.__next__() + except StopIteration: + raise KeyError + self.discard(value) + return value + + def toggle(self, value): + """Return True if it was added, False if deleted.""" + # XXX This implementation is not thread-safe + if value in self: + self.discard(value) + return False + else: + self.add(value) + return True + + def clear(self): + """This is slow (creates N new iterators!) but effective.""" + try: + while True: + self.pop() + except KeyError: + pass + + def __ior__(self, it): + for value in it: + self.add(value) + return self + + def __iand__(self, c): + for value in self: + if value not in c: + self.discard(value) + return self + + def __ixor__(self, it): + # This calls toggle(), so if that is overridded, we call the override + for value in it: + self.toggle(it) + return self + + def __isub__(self, it): + for value in it: + self.discard(value) + return self + +MutableSet.register(set) + + +### MAPPINGS ### + + +class Mapping: + __metaclass__ = ABCMeta + + @abstractmethod + def __getitem__(self, key): + raise KeyError + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + @abstractmethod + def __len__(self): + return 0 + + @abstractmethod + def __iter__(self): + while False: + yield None + + def keys(self): + return KeysView(self) + + def items(self): + return ItemsView(self) + + def values(self): + return ValuesView(self) + + +class MappingView: + __metaclass__ = ABCMeta + + def __init__(self, mapping): + self._mapping = mapping + + def __len__(self): + return len(self._mapping) + + +class KeysView(MappingView, Set): + + def __contains__(self, key): + return key in self._mapping + + def __iter__(self): + for key in self._mapping: + yield key + +KeysView.register(type({}.keys())) + + +class ItemsView(MappingView, Set): + + def __contains__(self, item): + key, value = item + try: + v = self._mapping[key] + except KeyError: + return False + else: + return v == value + + def __iter__(self): + for key in self._mapping: + yield (key, self._mapping[key]) + +ItemsView.register(type({}.items())) + + +class ValuesView(MappingView): + + def __contains__(self, value): + for key in self._mapping: + if value == self._mapping[key]: + return True + return False + + def __iter__(self): + for key in self._mapping: + yield self._mapping[key] + +ValuesView.register(type({}.values())) + + +class MutableMapping(Mapping): + + @abstractmethod + def __setitem__(self, key, value): + raise KeyError + + @abstractmethod + def __delitem__(self, key): + raise KeyError + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + try: + while True: + self.popitem() + except KeyError: + pass + + def update(self, other=(), **kwds): + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + +MutableMapping.register(dict) + + +### SEQUENCES ### + + +class Sequence: + __metaclass__ = ABCMeta + + """All the operations on a read-only sequence. + + Concrete subclasses must override __new__ or __init__, + __getitem__, and __len__. + """ + + @abstractmethod + def __getitem__(self, index): + raise IndexError + + @abstractmethod + def __len__(self): + return 0 + + def __iter__(self): + i = 0 + while True: + try: + v = self[i] + except IndexError: + break + yield v + i += 1 + + def __contains__(self, value): + for v in self: + if v == value: + return True + return False + + def __reversed__(self): + for i in reversed(range(len(self))): + yield self[i] + + def index(self, value): + for i, v in enumerate(self): + if v == value: + return i + raise ValueError + + def count(self, value): + return sum(1 for v in self if v == value) + +Sequence.register(tuple) +Sequence.register(basestring) +Sequence.register(buffer) + + +class MutableSequence(Sequence): + + @abstractmethod + def __setitem__(self, index, value): + raise IndexError + + @abstractmethod + def __delitem__(self, index): + raise IndexError + + @abstractmethod + def insert(self, index, value): + raise IndexError + + def append(self, value): + self.insert(len(self), value) + + def reverse(self): + n = len(self) + for i in range(n//2): + self[i], self[n-i-1] = self[n-i-1], self[i] + + def extend(self, values): + for v in values: + self.append(v) + + def pop(self, index=-1): + v = self[index] + del self[index] + return v + + def remove(self, value): + del self[self.index(value)] + + def __iadd__(self, values): + self.extend(values) + +MutableSequence.register(list) diff --git a/Lib/abc.py b/Lib/abc.py index d80fabd..6857029 100644 --- a/Lib/abc.py +++ b/Lib/abc.py @@ -69,7 +69,7 @@ class _Abstract(object): if (args or kwds) and cls.__init__ is object.__init__: raise TypeError("Can't pass arguments to __new__ " "without overriding __init__") - return object.__new__(cls) + return super(_Abstract, cls).__new__(cls) @classmethod def __subclasshook__(cls, subclass): diff --git a/Lib/collections.py b/Lib/collections.py index 7381a3a..356d961 100644 --- a/Lib/collections.py +++ b/Lib/collections.py @@ -5,6 +5,12 @@ from operator import itemgetter as _itemgetter from keyword import iskeyword as _iskeyword import sys as _sys +# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py. +# They should however be considered an integral part of collections.py. +from _abcoll import * +import _abcoll +__all__ += _abcoll.__all__ + def namedtuple(typename, field_names, verbose=False): """Returns a new subclass of tuple with named fields. diff --git a/Lib/test/mapping_tests.py b/Lib/test/mapping_tests.py index 4b0f797..c6857ab 100644 --- a/Lib/test/mapping_tests.py +++ b/Lib/test/mapping_tests.py @@ -557,6 +557,8 @@ class TestHashMappingProtocol(TestMappingProtocol): class BadEq(object): def __eq__(self, other): raise Exc() + def __hash__(self): + return 24 d = self._empty_mapping() d[BadEq()] = 42 @@ -642,6 +644,8 @@ class TestHashMappingProtocol(TestMappingProtocol): class BadCmp(object): def __eq__(self, other): raise Exc() + def __hash__(self): + return 42 d1 = self._full_mapping({BadCmp(): 1}) d2 = self._full_mapping({1: 1}) diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py index 7973b23..9694e2a 100755 --- a/Lib/test/regrtest.py +++ b/Lib/test/regrtest.py @@ -648,7 +648,7 @@ def cleanup_test_droppings(testname, verbose): def dash_R(the_module, test, indirect_test, huntrleaks): # This code is hackish and inelegant, but it seems to do the job. - import copy_reg + import copy_reg, _abcoll if not hasattr(sys, 'gettotalrefcount'): raise Exception("Tracking reference leaks requires a debug build " @@ -658,6 +658,12 @@ def dash_R(the_module, test, indirect_test, huntrleaks): fs = warnings.filters[:] ps = copy_reg.dispatch_table.copy() pic = sys.path_importer_cache.copy() + abcs = {} + for abc in [getattr(_abcoll, a) for a in _abcoll.__all__]: + for obj in abc.__subclasses__() + [abc]: + abcs[obj] = obj._abc_registry.copy() + + print >> sys.stderr, abcs if indirect_test: def run_the_test(): @@ -671,12 +677,12 @@ def dash_R(the_module, test, indirect_test, huntrleaks): repcount = nwarmup + ntracked print >> sys.stderr, "beginning", repcount, "repetitions" print >> sys.stderr, ("1234567890"*(repcount//10 + 1))[:repcount] - dash_R_cleanup(fs, ps, pic) + dash_R_cleanup(fs, ps, pic, abcs) for i in range(repcount): rc = sys.gettotalrefcount() run_the_test() sys.stderr.write('.') - dash_R_cleanup(fs, ps, pic) + dash_R_cleanup(fs, ps, pic, abcs) if i >= nwarmup: deltas.append(sys.gettotalrefcount() - rc - 2) print >> sys.stderr @@ -687,11 +693,11 @@ def dash_R(the_module, test, indirect_test, huntrleaks): print >> refrep, msg refrep.close() -def dash_R_cleanup(fs, ps, pic): +def dash_R_cleanup(fs, ps, pic, abcs): import gc, copy_reg import _strptime, linecache, dircache import urlparse, urllib, urllib2, mimetypes, doctest - import struct, filecmp + import struct, filecmp, _abcoll from distutils.dir_util import _path_created # Restore some original values. @@ -701,6 +707,13 @@ def dash_R_cleanup(fs, ps, pic): sys.path_importer_cache.clear() sys.path_importer_cache.update(pic) + # Clear ABC registries, restoring previously saved ABC registries. + for abc in [getattr(_abcoll, a) for a in _abcoll.__all__]: + for obj in abc.__subclasses__() + [abc]: + obj._abc_registry = abcs.get(obj, {}).copy() + obj._abc_cache.clear() + obj._abc_negative_cache.clear() + # Clear assorted module caches. _path_created.clear() re.purge() diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py index 3fd9bde..c760356 100644 --- a/Lib/test/test_abc.py +++ b/Lib/test/test_abc.py @@ -133,6 +133,20 @@ class TestABC(unittest.TestCase): self.failUnless(issubclass(MyInt, A)) self.failUnless(isinstance(42, A)) + def test_all_new_methods_are_called(self): + class A: + __metaclass__ = abc.ABCMeta + class B: + counter = 0 + def __new__(cls): + B.counter += 1 + return super(B, cls).__new__(cls) + class C(A, B): + pass + self.assertEqual(B.counter, 0) + C() + self.assertEqual(B.counter, 1) + def test_main(): test_support.run_unittest(TestABC) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 7c5b2dc..52bae9a 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1,6 +1,12 @@ import unittest from test import test_support from collections import namedtuple +from collections import Hashable, Iterable, Iterator +from collections import Sized, Container, Callable +from collections import Set, MutableSet +from collections import Mapping, MutableMapping +from collections import Sequence, MutableSequence + class TestNamedTuple(unittest.TestCase): @@ -86,9 +92,187 @@ class TestNamedTuple(unittest.TestCase): Dot = namedtuple('Dot', 'd') self.assertEqual(Dot(1), (1,)) + +class TestOneTrickPonyABCs(unittest.TestCase): + + def test_Hashable(self): + # Check some non-hashables + non_samples = [list(), set(), dict()] + for x in non_samples: + self.failIf(isinstance(x, Hashable), repr(x)) + self.failIf(issubclass(type(x), Hashable), repr(type(x))) + # Check some hashables + samples = [None, + int(), float(), complex(), + str(), + tuple(), frozenset(), + int, list, object, type, + ] + for x in samples: + self.failUnless(isinstance(x, Hashable), repr(x)) + self.failUnless(issubclass(type(x), Hashable), repr(type(x))) + self.assertRaises(TypeError, Hashable) + # Check direct subclassing + class H(Hashable): + def __hash__(self): + return super(H, self).__hash__() + self.assertEqual(hash(H()), 0) + self.failIf(issubclass(int, H)) + + def test_Iterable(self): + # Check some non-iterables + non_samples = [None, 42, 3.14, 1j] + for x in non_samples: + self.failIf(isinstance(x, Iterable), repr(x)) + self.failIf(issubclass(type(x), Iterable), repr(type(x))) + # Check some iterables + samples = [str(), + tuple(), list(), set(), frozenset(), dict(), + dict().keys(), dict().items(), dict().values(), + (lambda: (yield))(), + (x for x in []), + ] + for x in samples: + self.failUnless(isinstance(x, Iterable), repr(x)) + self.failUnless(issubclass(type(x), Iterable), repr(type(x))) + # Check direct subclassing + class I(Iterable): + def __iter__(self): + return super(I, self).__iter__() + self.assertEqual(list(I()), []) + self.failIf(issubclass(str, I)) + + def test_Iterator(self): + non_samples = [None, 42, 3.14, 1j, "".encode('ascii'), "", (), [], + {}, set()] + for x in non_samples: + self.failIf(isinstance(x, Iterator), repr(x)) + self.failIf(issubclass(type(x), Iterator), repr(type(x))) + samples = [iter(str()), + iter(tuple()), iter(list()), iter(dict()), + iter(set()), iter(frozenset()), + iter(dict().keys()), iter(dict().items()), + iter(dict().values()), + (lambda: (yield))(), + (x for x in []), + ] + for x in samples: + self.failUnless(isinstance(x, Iterator), repr(x)) + self.failUnless(issubclass(type(x), Iterator), repr(type(x))) + + def test_Sized(self): + non_samples = [None, 42, 3.14, 1j, + (lambda: (yield))(), + (x for x in []), + ] + for x in non_samples: + self.failIf(isinstance(x, Sized), repr(x)) + self.failIf(issubclass(type(x), Sized), repr(type(x))) + samples = [str(), + tuple(), list(), set(), frozenset(), dict(), + dict().keys(), dict().items(), dict().values(), + ] + for x in samples: + self.failUnless(isinstance(x, Sized), repr(x)) + self.failUnless(issubclass(type(x), Sized), repr(type(x))) + + def test_Container(self): + non_samples = [None, 42, 3.14, 1j, + (lambda: (yield))(), + (x for x in []), + ] + for x in non_samples: + self.failIf(isinstance(x, Container), repr(x)) + self.failIf(issubclass(type(x), Container), repr(type(x))) + samples = [str(), + tuple(), list(), set(), frozenset(), dict(), + dict().keys(), dict().items(), + ] + for x in samples: + self.failUnless(isinstance(x, Container), repr(x)) + self.failUnless(issubclass(type(x), Container), repr(type(x))) + + def test_Callable(self): + non_samples = [None, 42, 3.14, 1j, + "", "".encode('ascii'), (), [], {}, set(), + (lambda: (yield))(), + (x for x in []), + ] + for x in non_samples: + self.failIf(isinstance(x, Callable), repr(x)) + self.failIf(issubclass(type(x), Callable), repr(type(x))) + samples = [lambda: None, + type, int, object, + len, + list.append, [].append, + ] + for x in samples: + self.failUnless(isinstance(x, Callable), repr(x)) + self.failUnless(issubclass(type(x), Callable), repr(type(x))) + + def test_direct_subclassing(self): + for B in Hashable, Iterable, Iterator, Sized, Container, Callable: + class C(B): + pass + self.failUnless(issubclass(C, B)) + self.failIf(issubclass(int, C)) + + def test_registration(self): + for B in Hashable, Iterable, Iterator, Sized, Container, Callable: + class C: + __metaclass__ = type + __hash__ = None # Make sure it isn't hashable by default + self.failIf(issubclass(C, B), B.__name__) + B.register(C) + self.failUnless(issubclass(C, B)) + + +class TestCollectionABCs(unittest.TestCase): + + # XXX For now, we only test some virtual inheritance properties. + # We should also test the proper behavior of the collection ABCs + # as real base classes or mix-in classes. + + def test_Set(self): + for sample in [set, frozenset]: + self.failUnless(isinstance(sample(), Set)) + self.failUnless(issubclass(sample, Set)) + + def test_MutableSet(self): + self.failUnless(isinstance(set(), MutableSet)) + self.failUnless(issubclass(set, MutableSet)) + self.failIf(isinstance(frozenset(), MutableSet)) + self.failIf(issubclass(frozenset, MutableSet)) + + def test_Mapping(self): + for sample in [dict]: + self.failUnless(isinstance(sample(), Mapping)) + self.failUnless(issubclass(sample, Mapping)) + + def test_MutableMapping(self): + for sample in [dict]: + self.failUnless(isinstance(sample(), MutableMapping)) + self.failUnless(issubclass(sample, MutableMapping)) + + def test_Sequence(self): + for sample in [tuple, list, str]: + self.failUnless(isinstance(sample(), Sequence)) + self.failUnless(issubclass(sample, Sequence)) + self.failUnless(issubclass(basestring, Sequence)) + + def test_MutableSequence(self): + for sample in [tuple, str]: + self.failIf(isinstance(sample(), MutableSequence)) + self.failIf(issubclass(sample, MutableSequence)) + for sample in [list]: + self.failUnless(isinstance(sample(), MutableSequence)) + self.failUnless(issubclass(sample, MutableSequence)) + self.failIf(issubclass(basestring, MutableSequence)) + + def test_main(verbose=None): import collections as CollectionsModule - test_classes = [TestNamedTuple] + test_classes = [TestNamedTuple, TestOneTrickPonyABCs, TestCollectionABCs] test_support.run_unittest(*test_classes) test_support.run_doctest(CollectionsModule, verbose) diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 9f049ad..342e775 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -86,6 +86,8 @@ class DictTest(unittest.TestCase): class BadEq(object): def __eq__(self, other): raise Exc() + def __hash__(self): + return 24 d = {} d[BadEq()] = 42 @@ -397,6 +399,8 @@ class DictTest(unittest.TestCase): class BadCmp(object): def __eq__(self, other): raise Exc() + def __hash__(self): + return 42 d1 = {BadCmp(): 1} d2 = {1: 1} diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 978071b..bfb891c 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2127,13 +2127,6 @@ dict_init(PyObject *self, PyObject *args, PyObject *kwds) return dict_update_common(self, args, kwds, "dict"); } -static long -dict_nohash(PyObject *self) -{ - PyErr_SetString(PyExc_TypeError, "dict objects are unhashable"); - return -1; -} - static PyObject * dict_iter(PyDictObject *dict) { @@ -2165,7 +2158,7 @@ PyTypeObject PyDict_Type = { 0, /* tp_as_number */ &dict_as_sequence, /* tp_as_sequence */ &dict_as_mapping, /* tp_as_mapping */ - dict_nohash, /* tp_hash */ + 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ PyObject_GenericGetAttr, /* tp_getattro */ diff --git a/Objects/listobject.c b/Objects/listobject.c index fb5ce82..deb3ca5 100644 --- a/Objects/listobject.c +++ b/Objects/listobject.c @@ -2393,13 +2393,6 @@ list_init(PyListObject *self, PyObject *args, PyObject *kw) return 0; } -static long -list_nohash(PyObject *self) -{ - PyErr_SetString(PyExc_TypeError, "list objects are unhashable"); - return -1; -} - static PyObject *list_iter(PyObject *seq); static PyObject *list_reversed(PyListObject* seq, PyObject* unused); @@ -2694,7 +2687,7 @@ PyTypeObject PyList_Type = { 0, /* tp_as_number */ &list_as_sequence, /* tp_as_sequence */ &list_as_mapping, /* tp_as_mapping */ - list_nohash, /* tp_hash */ + 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ PyObject_GenericGetAttr, /* tp_getattro */ @@ -2959,4 +2952,3 @@ listreviter_len(listreviterobject *it) return 0; return len; } - diff --git a/Objects/object.c b/Objects/object.c index e75a03d..de385ea 100644 --- a/Objects/object.c +++ b/Objects/object.c @@ -1902,7 +1902,7 @@ static PyTypeObject PyNone_Type = { 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ - 0, /*tp_hash */ + (hashfunc)_Py_HashPointer, /*tp_hash */ }; PyObject _Py_NoneStruct = { diff --git a/Objects/setobject.c b/Objects/setobject.c index 3cbcd9e..b049d09 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -789,13 +789,6 @@ frozenset_hash(PyObject *self) return hash; } -static long -set_nohash(PyObject *self) -{ - PyErr_SetString(PyExc_TypeError, "set objects are unhashable"); - return -1; -} - /***** Set iterator type ***********************************************/ typedef struct { @@ -2012,7 +2005,7 @@ PyTypeObject PySet_Type = { &set_as_number, /* tp_as_number */ &set_as_sequence, /* tp_as_sequence */ 0, /* tp_as_mapping */ - set_nohash, /* tp_hash */ + 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ PyObject_GenericGetAttr, /* tp_getattro */ diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 59dec4a..1a221c8 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -2590,12 +2590,6 @@ object_str(PyObject *self) return f(self); } -static long -object_hash(PyObject *self) -{ - return _Py_HashPointer(self); -} - static PyObject * object_get_class(PyObject *self, void *closure) { @@ -3030,7 +3024,7 @@ PyTypeObject PyBaseObject_Type = { 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ - object_hash, /* tp_hash */ + (hashfunc)_Py_HashPointer, /* tp_hash */ 0, /* tp_call */ object_str, /* tp_str */ PyObject_GenericGetAttr, /* tp_getattro */ @@ -3236,6 +3230,33 @@ inherit_special(PyTypeObject *type, PyTypeObject *base) type->tp_flags |= Py_TPFLAGS_DICT_SUBCLASS; } +/* Map rich comparison operators to their __xx__ namesakes */ +static char *name_op[] = { + "__lt__", + "__le__", + "__eq__", + "__ne__", + "__gt__", + "__ge__", + "__cmp__", + /* These are only for overrides_hash(): */ + "__hash__", +}; + +static int +overrides_hash(PyTypeObject *type) +{ + int i; + PyObject *dict = type->tp_dict; + + assert(dict != NULL); + for (i = 0; i < 8; i++) { + if (PyDict_GetItemString(dict, name_op[i]) != NULL) + return 1; + } + return 0; +} + static void inherit_slots(PyTypeObject *type, PyTypeObject *base) { @@ -3367,7 +3388,8 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base) if (type->tp_flags & base->tp_flags & Py_TPFLAGS_HAVE_RICHCOMPARE) { if (type->tp_compare == NULL && type->tp_richcompare == NULL && - type->tp_hash == NULL) + type->tp_hash == NULL && + !overrides_hash(type)) { type->tp_compare = base->tp_compare; type->tp_richcompare = base->tp_richcompare; @@ -3548,6 +3570,18 @@ PyType_Ready(PyTypeObject *type) } } + /* Hack for tp_hash and __hash__. + If after all that, tp_hash is still NULL, and __hash__ is not in + tp_dict, set tp_dict['__hash__'] equal to None. + This signals that __hash__ is not inherited. + */ + if (type->tp_hash == NULL && + PyDict_GetItemString(type->tp_dict, "__hash__") == NULL && + PyDict_SetItemString(type->tp_dict, "__hash__", Py_None) < 0) + { + goto error; + } + /* Some more special stuff */ base = type->tp_base; if (base != NULL) { @@ -4937,16 +4971,6 @@ slot_tp_setattro(PyObject *self, PyObject *name, PyObject *value) return 0; } -/* Map rich comparison operators to their __xx__ namesakes */ -static char *name_op[] = { - "__lt__", - "__le__", - "__eq__", - "__ne__", - "__gt__", - "__ge__", -}; - static PyObject * half_richcompare(PyObject *self, PyObject *other, int op) { -- cgit v0.12