summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2019-08-08 05:42:54 (GMT)
committerGitHub <noreply@github.com>2019-08-08 05:42:54 (GMT)
commit662db125cddbca1db68116c547c290eb3943d98e (patch)
tree06151487dbe4493ef173dd8cc378f4b6cf5c0e4a /Lib
parent4c69be22df3852f17873a74d015528d9a8ae92d6 (diff)
downloadcpython-662db125cddbca1db68116c547c290eb3943d98e.zip
cpython-662db125cddbca1db68116c547c290eb3943d98e.tar.gz
cpython-662db125cddbca1db68116c547c290eb3943d98e.tar.bz2
bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (GH-14952)
They now return NotImplemented for unsupported type of the other operand.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/events.py24
-rw-r--r--Lib/distutils/tests/test_version.py16
-rw-r--r--Lib/distutils/version.py4
-rw-r--r--Lib/email/headerregistry.py8
-rw-r--r--Lib/importlib/_bootstrap.py2
-rw-r--r--Lib/test/test_asyncio/test_events.py23
-rw-r--r--Lib/test/test_email/test_headerregistry.py19
-rw-r--r--Lib/test/test_traceback.py16
-rw-r--r--Lib/test/test_weakref.py9
-rw-r--r--Lib/test/test_xmlrpc.py25
-rw-r--r--Lib/tkinter/__init__.py2
-rw-r--r--Lib/tkinter/font.py4
-rw-r--r--Lib/tkinter/test/test_tkinter/test_font.py3
-rw-r--r--Lib/tkinter/test/test_tkinter/test_variables.py13
-rw-r--r--Lib/traceback.py4
-rw-r--r--Lib/tracemalloc.py16
-rw-r--r--Lib/unittest/mock.py4
-rw-r--r--Lib/unittest/test/testmock/testmock.py8
-rw-r--r--Lib/weakref.py4
-rw-r--r--Lib/xmlrpc/client.py17
20 files changed, 178 insertions, 43 deletions
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index d381b1c..5fb5464 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -119,20 +119,24 @@ class TimerHandle(Handle):
return hash(self._when)
def __lt__(self, other):
- return self._when < other._when
+ if isinstance(other, TimerHandle):
+ return self._when < other._when
+ return NotImplemented
def __le__(self, other):
- if self._when < other._when:
- return True
- return self.__eq__(other)
+ if isinstance(other, TimerHandle):
+ return self._when < other._when or self.__eq__(other)
+ return NotImplemented
def __gt__(self, other):
- return self._when > other._when
+ if isinstance(other, TimerHandle):
+ return self._when > other._when
+ return NotImplemented
def __ge__(self, other):
- if self._when > other._when:
- return True
- return self.__eq__(other)
+ if isinstance(other, TimerHandle):
+ return self._when > other._when or self.__eq__(other)
+ return NotImplemented
def __eq__(self, other):
if isinstance(other, TimerHandle):
@@ -142,10 +146,6 @@ class TimerHandle(Handle):
self._cancelled == other._cancelled)
return NotImplemented
- def __ne__(self, other):
- equal = self.__eq__(other)
- return NotImplemented if equal is NotImplemented else not equal
-
def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)
diff --git a/Lib/distutils/tests/test_version.py b/Lib/distutils/tests/test_version.py
index 15f14c7..8671cd2 100644
--- a/Lib/distutils/tests/test_version.py
+++ b/Lib/distutils/tests/test_version.py
@@ -45,6 +45,14 @@ class VersionTestCase(unittest.TestCase):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
+ res = StrictVersion(v1)._cmp(v2)
+ self.assertEqual(res, wanted,
+ 'cmp(%s, %s) should be %s, got %s' %
+ (v1, v2, wanted, res))
+ res = StrictVersion(v1)._cmp(object())
+ self.assertIs(res, NotImplemented,
+ 'cmp(%s, %s) should be NotImplemented, got %s' %
+ (v1, v2, res))
def test_cmp(self):
@@ -63,6 +71,14 @@ class VersionTestCase(unittest.TestCase):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
+ res = LooseVersion(v1)._cmp(v2)
+ self.assertEqual(res, wanted,
+ 'cmp(%s, %s) should be %s, got %s' %
+ (v1, v2, wanted, res))
+ res = LooseVersion(v1)._cmp(object())
+ self.assertIs(res, NotImplemented,
+ 'cmp(%s, %s) should be NotImplemented, got %s' %
+ (v1, v2, res))
def test_suite():
return unittest.makeSuite(VersionTestCase)
diff --git a/Lib/distutils/version.py b/Lib/distutils/version.py
index af14cc1..c33beba 100644
--- a/Lib/distutils/version.py
+++ b/Lib/distutils/version.py
@@ -166,6 +166,8 @@ class StrictVersion (Version):
def _cmp (self, other):
if isinstance(other, str):
other = StrictVersion(other)
+ elif not isinstance(other, StrictVersion):
+ return NotImplemented
if self.version != other.version:
# numeric versions don't match
@@ -331,6 +333,8 @@ class LooseVersion (Version):
def _cmp (self, other):
if isinstance(other, str):
other = LooseVersion(other)
+ elif not isinstance(other, LooseVersion):
+ return NotImplemented
if self.version == other.version:
return 0
diff --git a/Lib/email/headerregistry.py b/Lib/email/headerregistry.py
index 8d1a202..dcc960b 100644
--- a/Lib/email/headerregistry.py
+++ b/Lib/email/headerregistry.py
@@ -97,8 +97,8 @@ class Address:
return self.addr_spec
def __eq__(self, other):
- if type(other) != type(self):
- return False
+ if not isinstance(other, Address):
+ return NotImplemented
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
@@ -150,8 +150,8 @@ class Group:
return "{}:{};".format(disp, adrstr)
def __eq__(self, other):
- if type(other) != type(self):
- return False
+ if not isinstance(other, Group):
+ return NotImplemented
return (self.display_name == other.display_name and
self.addresses == other.addresses)
diff --git a/Lib/importlib/_bootstrap.py b/Lib/importlib/_bootstrap.py
index 5e2f520..e17eeb6 100644
--- a/Lib/importlib/_bootstrap.py
+++ b/Lib/importlib/_bootstrap.py
@@ -371,7 +371,7 @@ class ModuleSpec:
self.cached == other.cached and
self.has_location == other.has_location)
except AttributeError:
- return False
+ return NotImplemented
@property
def cached(self):
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index e5ad72f..5bc1bc2 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -32,6 +32,7 @@ from asyncio import proactor_events
from asyncio import selector_events
from test.test_asyncio import utils as test_utils
from test import support
+from test.support import ALWAYS_EQ, LARGEST, SMALLEST
def tearDownModule():
@@ -2364,6 +2365,28 @@ class TimerTests(unittest.TestCase):
self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3))
+ with self.assertRaises(TypeError):
+ h1 < ()
+ with self.assertRaises(TypeError):
+ h1 > ()
+ with self.assertRaises(TypeError):
+ h1 <= ()
+ with self.assertRaises(TypeError):
+ h1 >= ()
+ self.assertFalse(h1 == ())
+ self.assertTrue(h1 != ())
+
+ self.assertTrue(h1 == ALWAYS_EQ)
+ self.assertFalse(h1 != ALWAYS_EQ)
+ self.assertTrue(h1 < LARGEST)
+ self.assertFalse(h1 > LARGEST)
+ self.assertTrue(h1 <= LARGEST)
+ self.assertFalse(h1 >= LARGEST)
+ self.assertFalse(h1 < SMALLEST)
+ self.assertTrue(h1 > SMALLEST)
+ self.assertFalse(h1 <= SMALLEST)
+ self.assertTrue(h1 >= SMALLEST)
+
class AbstractEventLoopTests(unittest.TestCase):
diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py
index 5d9b357..4758f4b 100644
--- a/Lib/test/test_email/test_headerregistry.py
+++ b/Lib/test/test_email/test_headerregistry.py
@@ -7,6 +7,7 @@ from email.message import Message
from test.test_email import TestEmailBase, parameterize
from email import headerregistry
from email.headerregistry import Address, Group
+from test.support import ALWAYS_EQ
DITTO = object()
@@ -1525,6 +1526,24 @@ class TestAddressAndGroup(TestEmailBase):
self.assertEqual(m['to'], 'foo bar:;')
self.assertEqual(m['to'].addresses, g.addresses)
+ def test_address_comparison(self):
+ a = Address('foo', 'bar', 'example.com')
+ self.assertEqual(Address('foo', 'bar', 'example.com'), a)
+ self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
+ self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
+ self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
+ self.assertFalse(a == object())
+ self.assertTrue(a == ALWAYS_EQ)
+
+ def test_group_comparison(self):
+ a = Address('foo', 'bar', 'example.com')
+ g = Group('foo bar', [a])
+ self.assertEqual(Group('foo bar', (a,)), g)
+ self.assertNotEqual(Group('baz', [a]), g)
+ self.assertNotEqual(Group('foo bar', []), g)
+ self.assertFalse(g == object())
+ self.assertTrue(g == ALWAYS_EQ)
+
class TestFolding(TestHeaderBase):
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 96d85e2..72dc7af 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -7,7 +7,7 @@ import sys
import unittest
import re
from test import support
-from test.support import TESTFN, Error, captured_output, unlink, cpython_only
+from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
from test.support.script_helper import assert_python_ok
import textwrap
@@ -887,6 +887,8 @@ class TestFrame(unittest.TestCase):
# operator fallbacks to FrameSummary.__eq__.
self.assertEqual(tuple(f), f)
self.assertIsNone(f.locals)
+ self.assertNotEqual(f, object())
+ self.assertEqual(f, ALWAYS_EQ)
def test_lazy_lines(self):
linecache.clearcache()
@@ -1083,6 +1085,18 @@ class TestTracebackException(unittest.TestCase):
self.assertEqual(exc_info[0], exc.exc_type)
self.assertEqual(str(exc_info[1]), str(exc))
+ def test_comparison(self):
+ try:
+ 1/0
+ except Exception:
+ exc_info = sys.exc_info()
+ exc = traceback.TracebackException(*exc_info)
+ exc2 = traceback.TracebackException(*exc_info)
+ self.assertIsNot(exc, exc2)
+ self.assertEqual(exc, exc2)
+ self.assertNotEqual(exc, object())
+ self.assertEqual(exc, ALWAYS_EQ)
+
def test_unhashable(self):
class UnhashableException(Exception):
def __eq__(self, other):
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index ce5bbfc..41f78e7 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -11,7 +11,7 @@ import time
import random
from test import support
-from test.support import script_helper
+from test.support import script_helper, ALWAYS_EQ
# Used in ReferencesTestCase.test_ref_created_during_del() .
ref_from_del = None
@@ -794,6 +794,10 @@ class ReferencesTestCase(TestBase):
self.assertTrue(a != c)
self.assertTrue(a == d)
self.assertFalse(a != d)
+ self.assertFalse(a == x)
+ self.assertTrue(a != x)
+ self.assertTrue(a == ALWAYS_EQ)
+ self.assertFalse(a != ALWAYS_EQ)
del x, y, z
gc.collect()
for r in a, b, c:
@@ -1102,6 +1106,9 @@ class WeakMethodTestCase(unittest.TestCase):
_ne(a, f)
_ne(b, e)
_ne(b, f)
+ # Compare with different types
+ _ne(a, x.some_method)
+ _eq(a, ALWAYS_EQ)
del x, y, z
gc.collect()
# Dead WeakMethods compare by identity
diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py
index 52bacc1..e5c3496 100644
--- a/Lib/test/test_xmlrpc.py
+++ b/Lib/test/test_xmlrpc.py
@@ -15,6 +15,7 @@ import re
import io
import contextlib
from test import support
+from test.support import ALWAYS_EQ, LARGEST, SMALLEST
try:
import gzip
@@ -530,14 +531,10 @@ class DateTimeTestCase(unittest.TestCase):
# some other types
dbytes = dstr.encode('ascii')
dtuple = now.timetuple()
- with self.assertRaises(TypeError):
- dtime == 1970
- with self.assertRaises(TypeError):
- dtime != dbytes
- with self.assertRaises(TypeError):
- dtime == bytearray(dbytes)
- with self.assertRaises(TypeError):
- dtime != dtuple
+ self.assertFalse(dtime == 1970)
+ self.assertTrue(dtime != dbytes)
+ self.assertFalse(dtime == bytearray(dbytes))
+ self.assertTrue(dtime != dtuple)
with self.assertRaises(TypeError):
dtime < float(1970)
with self.assertRaises(TypeError):
@@ -547,6 +544,18 @@ class DateTimeTestCase(unittest.TestCase):
with self.assertRaises(TypeError):
dtime >= dtuple
+ self.assertTrue(dtime == ALWAYS_EQ)
+ self.assertFalse(dtime != ALWAYS_EQ)
+ self.assertTrue(dtime < LARGEST)
+ self.assertFalse(dtime > LARGEST)
+ self.assertTrue(dtime <= LARGEST)
+ self.assertFalse(dtime >= LARGEST)
+ self.assertFalse(dtime < SMALLEST)
+ self.assertTrue(dtime > SMALLEST)
+ self.assertFalse(dtime <= SMALLEST)
+ self.assertTrue(dtime >= SMALLEST)
+
+
class BinaryTestCase(unittest.TestCase):
# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"
diff --git a/Lib/tkinter/__init__.py b/Lib/tkinter/__init__.py
index 9626a27..9258484 100644
--- a/Lib/tkinter/__init__.py
+++ b/Lib/tkinter/__init__.py
@@ -484,6 +484,8 @@ class Variable:
Note: if the Variable's master matters to behavior
also compare self._master == other._master
"""
+ if not isinstance(other, Variable):
+ return NotImplemented
return self.__class__.__name__ == other.__class__.__name__ \
and self._name == other._name
diff --git a/Lib/tkinter/font.py b/Lib/tkinter/font.py
index eeff454..15ad7ab 100644
--- a/Lib/tkinter/font.py
+++ b/Lib/tkinter/font.py
@@ -101,7 +101,9 @@ class Font:
return self.name
def __eq__(self, other):
- return isinstance(other, Font) and self.name == other.name
+ if not isinstance(other, Font):
+ return NotImplemented
+ return self.name == other.name
def __getitem__(self, key):
return self.cget(key)
diff --git a/Lib/tkinter/test/test_tkinter/test_font.py b/Lib/tkinter/test/test_tkinter/test_font.py
index 97cd87c..a021ea3 100644
--- a/Lib/tkinter/test/test_tkinter/test_font.py
+++ b/Lib/tkinter/test/test_tkinter/test_font.py
@@ -1,7 +1,7 @@
import unittest
import tkinter
from tkinter import font
-from test.support import requires, run_unittest, gc_collect
+from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
from tkinter.test.support import AbstractTkTest
requires('gui')
@@ -70,6 +70,7 @@ class FontTest(AbstractTkTest, unittest.TestCase):
self.assertEqual(font1, font2)
self.assertNotEqual(font1, font1.copy())
self.assertNotEqual(font1, 0)
+ self.assertEqual(font1, ALWAYS_EQ)
def test_measure(self):
self.assertIsInstance(self.font.measure('abc'), int)
diff --git a/Lib/tkinter/test/test_tkinter/test_variables.py b/Lib/tkinter/test/test_tkinter/test_variables.py
index 2eb1e12..08b7ded 100644
--- a/Lib/tkinter/test/test_tkinter/test_variables.py
+++ b/Lib/tkinter/test/test_tkinter/test_variables.py
@@ -2,6 +2,7 @@ import unittest
import gc
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
TclError)
+from test.support import ALWAYS_EQ
class Var(Variable):
@@ -59,11 +60,17 @@ class TestVariable(TestBase):
# values doesn't matter, only class and name are checked
v1 = Variable(self.root, name="abc")
v2 = Variable(self.root, name="abc")
+ self.assertIsNot(v1, v2)
self.assertEqual(v1, v2)
- v3 = Variable(self.root, name="abc")
- v4 = StringVar(self.root, name="abc")
- self.assertNotEqual(v3, v4)
+ v3 = StringVar(self.root, name="abc")
+ self.assertNotEqual(v1, v3)
+
+ V = type('Variable', (), {})
+ self.assertNotEqual(v1, V())
+
+ self.assertNotEqual(v1, object())
+ self.assertEqual(v1, ALWAYS_EQ)
def test_invalid_name(self):
with self.assertRaises(TypeError):
diff --git a/Lib/traceback.py b/Lib/traceback.py
index ab35da9..7a4c8e1 100644
--- a/Lib/traceback.py
+++ b/Lib/traceback.py
@@ -538,7 +538,9 @@ class TracebackException:
self.__cause__._load_lines()
def __eq__(self, other):
- return self.__dict__ == other.__dict__
+ if isinstance(other, TracebackException):
+ return self.__dict__ == other.__dict__
+ return NotImplemented
def __str__(self):
return self._str
diff --git a/Lib/tracemalloc.py b/Lib/tracemalloc.py
index 2c1ac3b..80b521c 100644
--- a/Lib/tracemalloc.py
+++ b/Lib/tracemalloc.py
@@ -43,6 +43,8 @@ class Statistic:
return hash((self.traceback, self.size, self.count))
def __eq__(self, other):
+ if not isinstance(other, Statistic):
+ return NotImplemented
return (self.traceback == other.traceback
and self.size == other.size
and self.count == other.count)
@@ -84,6 +86,8 @@ class StatisticDiff:
self.count, self.count_diff))
def __eq__(self, other):
+ if not isinstance(other, StatisticDiff):
+ return NotImplemented
return (self.traceback == other.traceback
and self.size == other.size
and self.size_diff == other.size_diff
@@ -153,9 +157,13 @@ class Frame:
return self._frame[1]
def __eq__(self, other):
+ if not isinstance(other, Frame):
+ return NotImplemented
return (self._frame == other._frame)
def __lt__(self, other):
+ if not isinstance(other, Frame):
+ return NotImplemented
return (self._frame < other._frame)
def __hash__(self):
@@ -200,9 +208,13 @@ class Traceback(Sequence):
return hash(self._frames)
def __eq__(self, other):
+ if not isinstance(other, Traceback):
+ return NotImplemented
return (self._frames == other._frames)
def __lt__(self, other):
+ if not isinstance(other, Traceback):
+ return NotImplemented
return (self._frames < other._frames)
def __str__(self):
@@ -271,6 +283,8 @@ class Trace:
return Traceback(self._trace[2])
def __eq__(self, other):
+ if not isinstance(other, Trace):
+ return NotImplemented
return (self._trace == other._trace)
def __hash__(self):
@@ -303,6 +317,8 @@ class _Traces(Sequence):
return trace._trace in self._traces
def __eq__(self, other):
+ if not isinstance(other, _Traces):
+ return NotImplemented
return (self._traces == other._traces)
def __repr__(self):
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index b3dc640..298b41e 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -2358,12 +2358,10 @@ class _Call(tuple):
def __eq__(self, other):
- if other is ANY:
- return True
try:
len_other = len(other)
except TypeError:
- return False
+ return NotImplemented
self_name = ''
if len(self) == 2:
diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py
index 18efd31..69b34e9 100644
--- a/Lib/unittest/test/testmock/testmock.py
+++ b/Lib/unittest/test/testmock/testmock.py
@@ -3,6 +3,7 @@ import re
import sys
import tempfile
+from test.support import ALWAYS_EQ
import unittest
from unittest.test.testmock.support import is_instance
from unittest import mock
@@ -322,6 +323,8 @@ class MockTest(unittest.TestCase):
self.assertFalse(mm != mock.ANY)
self.assertTrue(mock.ANY == mm)
self.assertFalse(mock.ANY != mm)
+ self.assertTrue(mm == ALWAYS_EQ)
+ self.assertFalse(mm != ALWAYS_EQ)
call1 = mock.call(mock.MagicMock())
call2 = mock.call(mock.ANY)
@@ -330,6 +333,11 @@ class MockTest(unittest.TestCase):
self.assertTrue(call2 == call1)
self.assertFalse(call2 != call1)
+ self.assertTrue(call1 == ALWAYS_EQ)
+ self.assertFalse(call1 != ALWAYS_EQ)
+ self.assertFalse(call1 == 1)
+ self.assertTrue(call1 != 1)
+
def test_assert_called_with(self):
mock = Mock()
diff --git a/Lib/weakref.py b/Lib/weakref.py
index fa7559b..560deee 100644
--- a/Lib/weakref.py
+++ b/Lib/weakref.py
@@ -75,14 +75,14 @@ class WeakMethod(ref):
if not self._alive or not other._alive:
return self is other
return ref.__eq__(self, other) and self._func_ref == other._func_ref
- return False
+ return NotImplemented
def __ne__(self, other):
if isinstance(other, WeakMethod):
if not self._alive or not other._alive:
return self is not other
return ref.__ne__(self, other) or self._func_ref != other._func_ref
- return True
+ return NotImplemented
__hash__ = ref.__hash__
diff --git a/Lib/xmlrpc/client.py b/Lib/xmlrpc/client.py
index b987574..246ef27 100644
--- a/Lib/xmlrpc/client.py
+++ b/Lib/xmlrpc/client.py
@@ -313,31 +313,38 @@ class DateTime:
s = self.timetuple()
o = other.timetuple()
else:
- otype = (hasattr(other, "__class__")
- and other.__class__.__name__
- or type(other))
- raise TypeError("Can't compare %s and %s" %
- (self.__class__.__name__, otype))
+ s = self
+ o = NotImplemented
return s, o
def __lt__(self, other):
s, o = self.make_comparable(other)
+ if o is NotImplemented:
+ return NotImplemented
return s < o
def __le__(self, other):
s, o = self.make_comparable(other)
+ if o is NotImplemented:
+ return NotImplemented
return s <= o
def __gt__(self, other):
s, o = self.make_comparable(other)
+ if o is NotImplemented:
+ return NotImplemented
return s > o
def __ge__(self, other):
s, o = self.make_comparable(other)
+ if o is NotImplemented:
+ return NotImplemented
return s >= o
def __eq__(self, other):
s, o = self.make_comparable(other)
+ if o is NotImplemented:
+ return NotImplemented
return s == o
def timetuple(self):