diff options
author | Nick Coghlan <ncoghlan@gmail.com> | 2012-07-07 13:05:59 (GMT) |
---|---|---|
committer | Nick Coghlan <ncoghlan@gmail.com> | 2012-07-07 13:05:59 (GMT) |
commit | 9a9c28ce7a051b37a91e4fc7aef70bcdcda25047 (patch) | |
tree | 23e80dcda26b4e1448de538eabf2800d9752851f /Lib | |
parent | d46f7d209b0d6db48f63c8317df9bfefbed73ae7 (diff) | |
download | cpython-9a9c28ce7a051b37a91e4fc7aef70bcdcda25047.zip cpython-9a9c28ce7a051b37a91e4fc7aef70bcdcda25047.tar.gz cpython-9a9c28ce7a051b37a91e4fc7aef70bcdcda25047.tar.bz2 |
Issue 14814: Correctly return NotImplemented from ipaddress._BaseNetwork.__eq__
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ipaddress.py | 12 | ||||
-rw-r--r-- | Lib/test/test_ipaddress.py | 17 |
2 files changed, 22 insertions, 7 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index e788c0a5..b1e07fc 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -651,12 +651,12 @@ class _BaseNetwork(_IPAddressBase): return not lt def __eq__(self, other): - if not isinstance(other, _BaseNetwork): - raise TypeError('%s and %s are not of the same type' % ( - self, other)) - return (self._version == other._version and - self.network_address == other.network_address and - int(self.netmask) == int(other.netmask)) + try: + return (self._version == other._version and + self.network_address == other.network_address and + int(self.netmask) == int(other.netmask)) + except AttributeError: + return NotImplemented def __ne__(self, other): eq = self.__eq__(other) diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index 2ac37e1..417c986 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -462,7 +462,6 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(128, ipaddress._count_righthand_zero_bits(0, 128)) self.assertEqual("IPv4Network('1.2.3.0/24')", repr(self.ipv4_network)) self.assertEqual('0x1020318', hex(self.ipv4_network)) - self.assertRaises(TypeError, self.ipv4_network.__eq__, object()) def testMissingAddressVersion(self): class Broken(ipaddress._BaseAddress): @@ -496,6 +495,22 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(str(self.ipv6_network.hostmask), '::ffff:ffff:ffff:ffff') + def testEqualityChecks(self): + # __eq__ should never raise TypeError directly + other = object() + def assertEqualityNotImplemented(instance): + self.assertEqual(instance.__eq__(other), NotImplemented) + self.assertEqual(instance.__ne__(other), NotImplemented) + self.assertFalse(instance == other) + self.assertTrue(instance != other) + + assertEqualityNotImplemented(self.ipv4_address) + assertEqualityNotImplemented(self.ipv4_network) + assertEqualityNotImplemented(self.ipv4_interface) + assertEqualityNotImplemented(self.ipv6_address) + assertEqualityNotImplemented(self.ipv6_network) + assertEqualityNotImplemented(self.ipv6_interface) + def testBadVersionComparison(self): # These should always raise TypeError v4addr = ipaddress.ip_address('1.1.1.1') |