diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2019-08-08 05:42:54 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-08-08 05:42:54 (GMT) |
commit | 662db125cddbca1db68116c547c290eb3943d98e (patch) | |
tree | 06151487dbe4493ef173dd8cc378f4b6cf5c0e4a /Lib | |
parent | 4c69be22df3852f17873a74d015528d9a8ae92d6 (diff) | |
download | cpython-662db125cddbca1db68116c547c290eb3943d98e.zip cpython-662db125cddbca1db68116c547c290eb3943d98e.tar.gz cpython-662db125cddbca1db68116c547c290eb3943d98e.tar.bz2 |
bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (GH-14952)
They now return NotImplemented for unsupported type of the other operand.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/events.py | 24 | ||||
-rw-r--r-- | Lib/distutils/tests/test_version.py | 16 | ||||
-rw-r--r-- | Lib/distutils/version.py | 4 | ||||
-rw-r--r-- | Lib/email/headerregistry.py | 8 | ||||
-rw-r--r-- | Lib/importlib/_bootstrap.py | 2 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 23 | ||||
-rw-r--r-- | Lib/test/test_email/test_headerregistry.py | 19 | ||||
-rw-r--r-- | Lib/test/test_traceback.py | 16 | ||||
-rw-r--r-- | Lib/test/test_weakref.py | 9 | ||||
-rw-r--r-- | Lib/test/test_xmlrpc.py | 25 | ||||
-rw-r--r-- | Lib/tkinter/__init__.py | 2 | ||||
-rw-r--r-- | Lib/tkinter/font.py | 4 | ||||
-rw-r--r-- | Lib/tkinter/test/test_tkinter/test_font.py | 3 | ||||
-rw-r--r-- | Lib/tkinter/test/test_tkinter/test_variables.py | 13 | ||||
-rw-r--r-- | Lib/traceback.py | 4 | ||||
-rw-r--r-- | Lib/tracemalloc.py | 16 | ||||
-rw-r--r-- | Lib/unittest/mock.py | 4 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testmock.py | 8 | ||||
-rw-r--r-- | Lib/weakref.py | 4 | ||||
-rw-r--r-- | Lib/xmlrpc/client.py | 17 |
20 files changed, 178 insertions, 43 deletions
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index d381b1c..5fb5464 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -119,20 +119,24 @@ class TimerHandle(Handle): return hash(self._when) def __lt__(self, other): - return self._when < other._when + if isinstance(other, TimerHandle): + return self._when < other._when + return NotImplemented def __le__(self, other): - if self._when < other._when: - return True - return self.__eq__(other) + if isinstance(other, TimerHandle): + return self._when < other._when or self.__eq__(other) + return NotImplemented def __gt__(self, other): - return self._when > other._when + if isinstance(other, TimerHandle): + return self._when > other._when + return NotImplemented def __ge__(self, other): - if self._when > other._when: - return True - return self.__eq__(other) + if isinstance(other, TimerHandle): + return self._when > other._when or self.__eq__(other) + return NotImplemented def __eq__(self, other): if isinstance(other, TimerHandle): @@ -142,10 +146,6 @@ class TimerHandle(Handle): self._cancelled == other._cancelled) return NotImplemented - def __ne__(self, other): - equal = self.__eq__(other) - return NotImplemented if equal is NotImplemented else not equal - def cancel(self): if not self._cancelled: self._loop._timer_handle_cancelled(self) diff --git a/Lib/distutils/tests/test_version.py b/Lib/distutils/tests/test_version.py index 15f14c7..8671cd2 100644 --- a/Lib/distutils/tests/test_version.py +++ b/Lib/distutils/tests/test_version.py @@ -45,6 +45,14 @@ class VersionTestCase(unittest.TestCase): self.assertEqual(res, wanted, 'cmp(%s, %s) should be %s, got %s' % (v1, v2, wanted, res)) + res = StrictVersion(v1)._cmp(v2) + self.assertEqual(res, wanted, + 'cmp(%s, %s) should be %s, got %s' % + (v1, v2, wanted, res)) + res = StrictVersion(v1)._cmp(object()) + self.assertIs(res, NotImplemented, + 'cmp(%s, %s) should be NotImplemented, got %s' % + (v1, v2, res)) def test_cmp(self): @@ -63,6 +71,14 @@ class VersionTestCase(unittest.TestCase): self.assertEqual(res, wanted, 'cmp(%s, %s) should be %s, got %s' % (v1, v2, wanted, res)) + res = LooseVersion(v1)._cmp(v2) + self.assertEqual(res, wanted, + 'cmp(%s, %s) should be %s, got %s' % + (v1, v2, wanted, res)) + res = LooseVersion(v1)._cmp(object()) + self.assertIs(res, NotImplemented, + 'cmp(%s, %s) should be NotImplemented, got %s' % + (v1, v2, res)) def test_suite(): return unittest.makeSuite(VersionTestCase) diff --git a/Lib/distutils/version.py b/Lib/distutils/version.py index af14cc1..c33beba 100644 --- a/Lib/distutils/version.py +++ b/Lib/distutils/version.py @@ -166,6 +166,8 @@ class StrictVersion (Version): def _cmp (self, other): if isinstance(other, str): other = StrictVersion(other) + elif not isinstance(other, StrictVersion): + return NotImplemented if self.version != other.version: # numeric versions don't match @@ -331,6 +333,8 @@ class LooseVersion (Version): def _cmp (self, other): if isinstance(other, str): other = LooseVersion(other) + elif not isinstance(other, LooseVersion): + return NotImplemented if self.version == other.version: return 0 diff --git a/Lib/email/headerregistry.py b/Lib/email/headerregistry.py index 8d1a202..dcc960b 100644 --- a/Lib/email/headerregistry.py +++ b/Lib/email/headerregistry.py @@ -97,8 +97,8 @@ class Address: return self.addr_spec def __eq__(self, other): - if type(other) != type(self): - return False + if not isinstance(other, Address): + return NotImplemented return (self.display_name == other.display_name and self.username == other.username and self.domain == other.domain) @@ -150,8 +150,8 @@ class Group: return "{}:{};".format(disp, adrstr) def __eq__(self, other): - if type(other) != type(self): - return False + if not isinstance(other, Group): + return NotImplemented return (self.display_name == other.display_name and self.addresses == other.addresses) diff --git a/Lib/importlib/_bootstrap.py b/Lib/importlib/_bootstrap.py index 5e2f520..e17eeb6 100644 --- a/Lib/importlib/_bootstrap.py +++ b/Lib/importlib/_bootstrap.py @@ -371,7 +371,7 @@ class ModuleSpec: self.cached == other.cached and self.has_location == other.has_location) except AttributeError: - return False + return NotImplemented @property def cached(self): diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index e5ad72f..5bc1bc2 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -32,6 +32,7 @@ from asyncio import proactor_events from asyncio import selector_events from test.test_asyncio import utils as test_utils from test import support +from test.support import ALWAYS_EQ, LARGEST, SMALLEST def tearDownModule(): @@ -2364,6 +2365,28 @@ class TimerTests(unittest.TestCase): self.assertIs(NotImplemented, h1.__eq__(h3)) self.assertIs(NotImplemented, h1.__ne__(h3)) + with self.assertRaises(TypeError): + h1 < () + with self.assertRaises(TypeError): + h1 > () + with self.assertRaises(TypeError): + h1 <= () + with self.assertRaises(TypeError): + h1 >= () + self.assertFalse(h1 == ()) + self.assertTrue(h1 != ()) + + self.assertTrue(h1 == ALWAYS_EQ) + self.assertFalse(h1 != ALWAYS_EQ) + self.assertTrue(h1 < LARGEST) + self.assertFalse(h1 > LARGEST) + self.assertTrue(h1 <= LARGEST) + self.assertFalse(h1 >= LARGEST) + self.assertFalse(h1 < SMALLEST) + self.assertTrue(h1 > SMALLEST) + self.assertFalse(h1 <= SMALLEST) + self.assertTrue(h1 >= SMALLEST) + class AbstractEventLoopTests(unittest.TestCase): diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py index 5d9b357..4758f4b 100644 --- a/Lib/test/test_email/test_headerregistry.py +++ b/Lib/test/test_email/test_headerregistry.py @@ -7,6 +7,7 @@ from email.message import Message from test.test_email import TestEmailBase, parameterize from email import headerregistry from email.headerregistry import Address, Group +from test.support import ALWAYS_EQ DITTO = object() @@ -1525,6 +1526,24 @@ class TestAddressAndGroup(TestEmailBase): self.assertEqual(m['to'], 'foo bar:;') self.assertEqual(m['to'].addresses, g.addresses) + def test_address_comparison(self): + a = Address('foo', 'bar', 'example.com') + self.assertEqual(Address('foo', 'bar', 'example.com'), a) + self.assertNotEqual(Address('baz', 'bar', 'example.com'), a) + self.assertNotEqual(Address('foo', 'baz', 'example.com'), a) + self.assertNotEqual(Address('foo', 'bar', 'baz'), a) + self.assertFalse(a == object()) + self.assertTrue(a == ALWAYS_EQ) + + def test_group_comparison(self): + a = Address('foo', 'bar', 'example.com') + g = Group('foo bar', [a]) + self.assertEqual(Group('foo bar', (a,)), g) + self.assertNotEqual(Group('baz', [a]), g) + self.assertNotEqual(Group('foo bar', []), g) + self.assertFalse(g == object()) + self.assertTrue(g == ALWAYS_EQ) + class TestFolding(TestHeaderBase): diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py index 96d85e2..72dc7af 100644 --- a/Lib/test/test_traceback.py +++ b/Lib/test/test_traceback.py @@ -7,7 +7,7 @@ import sys import unittest import re from test import support -from test.support import TESTFN, Error, captured_output, unlink, cpython_only +from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ from test.support.script_helper import assert_python_ok import textwrap @@ -887,6 +887,8 @@ class TestFrame(unittest.TestCase): # operator fallbacks to FrameSummary.__eq__. self.assertEqual(tuple(f), f) self.assertIsNone(f.locals) + self.assertNotEqual(f, object()) + self.assertEqual(f, ALWAYS_EQ) def test_lazy_lines(self): linecache.clearcache() @@ -1083,6 +1085,18 @@ class TestTracebackException(unittest.TestCase): self.assertEqual(exc_info[0], exc.exc_type) self.assertEqual(str(exc_info[1]), str(exc)) + def test_comparison(self): + try: + 1/0 + except Exception: + exc_info = sys.exc_info() + exc = traceback.TracebackException(*exc_info) + exc2 = traceback.TracebackException(*exc_info) + self.assertIsNot(exc, exc2) + self.assertEqual(exc, exc2) + self.assertNotEqual(exc, object()) + self.assertEqual(exc, ALWAYS_EQ) + def test_unhashable(self): class UnhashableException(Exception): def __eq__(self, other): diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index ce5bbfc..41f78e7 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -11,7 +11,7 @@ import time import random from test import support -from test.support import script_helper +from test.support import script_helper, ALWAYS_EQ # Used in ReferencesTestCase.test_ref_created_during_del() . ref_from_del = None @@ -794,6 +794,10 @@ class ReferencesTestCase(TestBase): self.assertTrue(a != c) self.assertTrue(a == d) self.assertFalse(a != d) + self.assertFalse(a == x) + self.assertTrue(a != x) + self.assertTrue(a == ALWAYS_EQ) + self.assertFalse(a != ALWAYS_EQ) del x, y, z gc.collect() for r in a, b, c: @@ -1102,6 +1106,9 @@ class WeakMethodTestCase(unittest.TestCase): _ne(a, f) _ne(b, e) _ne(b, f) + # Compare with different types + _ne(a, x.some_method) + _eq(a, ALWAYS_EQ) del x, y, z gc.collect() # Dead WeakMethods compare by identity diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index 52bacc1..e5c3496 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -15,6 +15,7 @@ import re import io import contextlib from test import support +from test.support import ALWAYS_EQ, LARGEST, SMALLEST try: import gzip @@ -530,14 +531,10 @@ class DateTimeTestCase(unittest.TestCase): # some other types dbytes = dstr.encode('ascii') dtuple = now.timetuple() - with self.assertRaises(TypeError): - dtime == 1970 - with self.assertRaises(TypeError): - dtime != dbytes - with self.assertRaises(TypeError): - dtime == bytearray(dbytes) - with self.assertRaises(TypeError): - dtime != dtuple + self.assertFalse(dtime == 1970) + self.assertTrue(dtime != dbytes) + self.assertFalse(dtime == bytearray(dbytes)) + self.assertTrue(dtime != dtuple) with self.assertRaises(TypeError): dtime < float(1970) with self.assertRaises(TypeError): @@ -547,6 +544,18 @@ class DateTimeTestCase(unittest.TestCase): with self.assertRaises(TypeError): dtime >= dtuple + self.assertTrue(dtime == ALWAYS_EQ) + self.assertFalse(dtime != ALWAYS_EQ) + self.assertTrue(dtime < LARGEST) + self.assertFalse(dtime > LARGEST) + self.assertTrue(dtime <= LARGEST) + self.assertFalse(dtime >= LARGEST) + self.assertFalse(dtime < SMALLEST) + self.assertTrue(dtime > SMALLEST) + self.assertFalse(dtime <= SMALLEST) + self.assertTrue(dtime >= SMALLEST) + + class BinaryTestCase(unittest.TestCase): # XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff" diff --git a/Lib/tkinter/__init__.py b/Lib/tkinter/__init__.py index 9626a27..9258484 100644 --- a/Lib/tkinter/__init__.py +++ b/Lib/tkinter/__init__.py @@ -484,6 +484,8 @@ class Variable: Note: if the Variable's master matters to behavior also compare self._master == other._master """ + if not isinstance(other, Variable): + return NotImplemented return self.__class__.__name__ == other.__class__.__name__ \ and self._name == other._name diff --git a/Lib/tkinter/font.py b/Lib/tkinter/font.py index eeff454..15ad7ab 100644 --- a/Lib/tkinter/font.py +++ b/Lib/tkinter/font.py @@ -101,7 +101,9 @@ class Font: return self.name def __eq__(self, other): - return isinstance(other, Font) and self.name == other.name + if not isinstance(other, Font): + return NotImplemented + return self.name == other.name def __getitem__(self, key): return self.cget(key) diff --git a/Lib/tkinter/test/test_tkinter/test_font.py b/Lib/tkinter/test/test_tkinter/test_font.py index 97cd87c..a021ea3 100644 --- a/Lib/tkinter/test/test_tkinter/test_font.py +++ b/Lib/tkinter/test/test_tkinter/test_font.py @@ -1,7 +1,7 @@ import unittest import tkinter from tkinter import font -from test.support import requires, run_unittest, gc_collect +from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ from tkinter.test.support import AbstractTkTest requires('gui') @@ -70,6 +70,7 @@ class FontTest(AbstractTkTest, unittest.TestCase): self.assertEqual(font1, font2) self.assertNotEqual(font1, font1.copy()) self.assertNotEqual(font1, 0) + self.assertEqual(font1, ALWAYS_EQ) def test_measure(self): self.assertIsInstance(self.font.measure('abc'), int) diff --git a/Lib/tkinter/test/test_tkinter/test_variables.py b/Lib/tkinter/test/test_tkinter/test_variables.py index 2eb1e12..08b7ded 100644 --- a/Lib/tkinter/test/test_tkinter/test_variables.py +++ b/Lib/tkinter/test/test_tkinter/test_variables.py @@ -2,6 +2,7 @@ import unittest import gc from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl, TclError) +from test.support import ALWAYS_EQ class Var(Variable): @@ -59,11 +60,17 @@ class TestVariable(TestBase): # values doesn't matter, only class and name are checked v1 = Variable(self.root, name="abc") v2 = Variable(self.root, name="abc") + self.assertIsNot(v1, v2) self.assertEqual(v1, v2) - v3 = Variable(self.root, name="abc") - v4 = StringVar(self.root, name="abc") - self.assertNotEqual(v3, v4) + v3 = StringVar(self.root, name="abc") + self.assertNotEqual(v1, v3) + + V = type('Variable', (), {}) + self.assertNotEqual(v1, V()) + + self.assertNotEqual(v1, object()) + self.assertEqual(v1, ALWAYS_EQ) def test_invalid_name(self): with self.assertRaises(TypeError): diff --git a/Lib/traceback.py b/Lib/traceback.py index ab35da9..7a4c8e1 100644 --- a/Lib/traceback.py +++ b/Lib/traceback.py @@ -538,7 +538,9 @@ class TracebackException: self.__cause__._load_lines() def __eq__(self, other): - return self.__dict__ == other.__dict__ + if isinstance(other, TracebackException): + return self.__dict__ == other.__dict__ + return NotImplemented def __str__(self): return self._str diff --git a/Lib/tracemalloc.py b/Lib/tracemalloc.py index 2c1ac3b..80b521c 100644 --- a/Lib/tracemalloc.py +++ b/Lib/tracemalloc.py @@ -43,6 +43,8 @@ class Statistic: return hash((self.traceback, self.size, self.count)) def __eq__(self, other): + if not isinstance(other, Statistic): + return NotImplemented return (self.traceback == other.traceback and self.size == other.size and self.count == other.count) @@ -84,6 +86,8 @@ class StatisticDiff: self.count, self.count_diff)) def __eq__(self, other): + if not isinstance(other, StatisticDiff): + return NotImplemented return (self.traceback == other.traceback and self.size == other.size and self.size_diff == other.size_diff @@ -153,9 +157,13 @@ class Frame: return self._frame[1] def __eq__(self, other): + if not isinstance(other, Frame): + return NotImplemented return (self._frame == other._frame) def __lt__(self, other): + if not isinstance(other, Frame): + return NotImplemented return (self._frame < other._frame) def __hash__(self): @@ -200,9 +208,13 @@ class Traceback(Sequence): return hash(self._frames) def __eq__(self, other): + if not isinstance(other, Traceback): + return NotImplemented return (self._frames == other._frames) def __lt__(self, other): + if not isinstance(other, Traceback): + return NotImplemented return (self._frames < other._frames) def __str__(self): @@ -271,6 +283,8 @@ class Trace: return Traceback(self._trace[2]) def __eq__(self, other): + if not isinstance(other, Trace): + return NotImplemented return (self._trace == other._trace) def __hash__(self): @@ -303,6 +317,8 @@ class _Traces(Sequence): return trace._trace in self._traces def __eq__(self, other): + if not isinstance(other, _Traces): + return NotImplemented return (self._traces == other._traces) def __repr__(self): diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index b3dc640..298b41e 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -2358,12 +2358,10 @@ class _Call(tuple): def __eq__(self, other): - if other is ANY: - return True try: len_other = len(other) except TypeError: - return False + return NotImplemented self_name = '' if len(self) == 2: diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py index 18efd31..69b34e9 100644 --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -3,6 +3,7 @@ import re import sys import tempfile +from test.support import ALWAYS_EQ import unittest from unittest.test.testmock.support import is_instance from unittest import mock @@ -322,6 +323,8 @@ class MockTest(unittest.TestCase): self.assertFalse(mm != mock.ANY) self.assertTrue(mock.ANY == mm) self.assertFalse(mock.ANY != mm) + self.assertTrue(mm == ALWAYS_EQ) + self.assertFalse(mm != ALWAYS_EQ) call1 = mock.call(mock.MagicMock()) call2 = mock.call(mock.ANY) @@ -330,6 +333,11 @@ class MockTest(unittest.TestCase): self.assertTrue(call2 == call1) self.assertFalse(call2 != call1) + self.assertTrue(call1 == ALWAYS_EQ) + self.assertFalse(call1 != ALWAYS_EQ) + self.assertFalse(call1 == 1) + self.assertTrue(call1 != 1) + def test_assert_called_with(self): mock = Mock() diff --git a/Lib/weakref.py b/Lib/weakref.py index fa7559b..560deee 100644 --- a/Lib/weakref.py +++ b/Lib/weakref.py @@ -75,14 +75,14 @@ class WeakMethod(ref): if not self._alive or not other._alive: return self is other return ref.__eq__(self, other) and self._func_ref == other._func_ref - return False + return NotImplemented def __ne__(self, other): if isinstance(other, WeakMethod): if not self._alive or not other._alive: return self is not other return ref.__ne__(self, other) or self._func_ref != other._func_ref - return True + return NotImplemented __hash__ = ref.__hash__ diff --git a/Lib/xmlrpc/client.py b/Lib/xmlrpc/client.py index b987574..246ef27 100644 --- a/Lib/xmlrpc/client.py +++ b/Lib/xmlrpc/client.py @@ -313,31 +313,38 @@ class DateTime: s = self.timetuple() o = other.timetuple() else: - otype = (hasattr(other, "__class__") - and other.__class__.__name__ - or type(other)) - raise TypeError("Can't compare %s and %s" % - (self.__class__.__name__, otype)) + s = self + o = NotImplemented return s, o def __lt__(self, other): s, o = self.make_comparable(other) + if o is NotImplemented: + return NotImplemented return s < o def __le__(self, other): s, o = self.make_comparable(other) + if o is NotImplemented: + return NotImplemented return s <= o def __gt__(self, other): s, o = self.make_comparable(other) + if o is NotImplemented: + return NotImplemented return s > o def __ge__(self, other): s, o = self.make_comparable(other) + if o is NotImplemented: + return NotImplemented return s >= o def __eq__(self, other): s, o = self.make_comparable(other) + if o is NotImplemented: + return NotImplemented return s == o def timetuple(self): |