summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ipaddress.py47
-rw-r--r--Lib/test/test_ipaddress.py47
2 files changed, 49 insertions, 45 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py
index 2ba98d8..6c225f3 100644
--- a/Lib/ipaddress.py
+++ b/Lib/ipaddress.py
@@ -382,40 +382,7 @@ def get_mixed_type_key(obj):
return NotImplemented
-class _TotalOrderingMixin:
- # Helper that derives the other comparison operations from
- # __lt__ and __eq__
- # We avoid functools.total_ordering because it doesn't handle
- # NotImplemented correctly yet (http://bugs.python.org/issue10042)
- def __eq__(self, other):
- raise NotImplementedError
- def __ne__(self, other):
- equal = self.__eq__(other)
- if equal is NotImplemented:
- return NotImplemented
- return not equal
- def __lt__(self, other):
- raise NotImplementedError
- def __le__(self, other):
- less = self.__lt__(other)
- if less is NotImplemented or not less:
- return self.__eq__(other)
- return less
- def __gt__(self, other):
- less = self.__lt__(other)
- if less is NotImplemented:
- return NotImplemented
- equal = self.__eq__(other)
- if equal is NotImplemented:
- return NotImplemented
- return not (less or equal)
- def __ge__(self, other):
- less = self.__lt__(other)
- if less is NotImplemented:
- return NotImplemented
- return not less
-
-class _IPAddressBase(_TotalOrderingMixin):
+class _IPAddressBase:
"""The mother class."""
@@ -567,6 +534,7 @@ class _IPAddressBase(_TotalOrderingMixin):
return self.__class__, (str(self),)
+@functools.total_ordering
class _BaseAddress(_IPAddressBase):
"""A generic IP object.
@@ -586,12 +554,11 @@ class _BaseAddress(_IPAddressBase):
return NotImplemented
def __lt__(self, other):
+ if not isinstance(other, _BaseAddress):
+ return NotImplemented
if self._version != other._version:
raise TypeError('%s and %s are not of the same version' % (
self, other))
- if not isinstance(other, _BaseAddress):
- raise TypeError('%s and %s are not of the same type' % (
- self, other))
if self._ip != other._ip:
return self._ip < other._ip
return False
@@ -624,6 +591,7 @@ class _BaseAddress(_IPAddressBase):
return self.__class__, (self._ip,)
+@functools.total_ordering
class _BaseNetwork(_IPAddressBase):
"""A generic IP network object.
@@ -673,12 +641,11 @@ class _BaseNetwork(_IPAddressBase):
return self._address_class(broadcast + n)
def __lt__(self, other):
+ if not isinstance(other, _BaseNetwork):
+ return NotImplemented
if self._version != other._version:
raise TypeError('%s and %s are not of the same version' % (
self, other))
- if not isinstance(other, _BaseNetwork):
- raise TypeError('%s and %s are not of the same type' % (
- self, other))
if self.network_address != other.network_address:
return self.network_address < other.network_address
if self.netmask != other.netmask:
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
index 5ec2cd4..e985329 100644
--- a/Lib/test/test_ipaddress.py
+++ b/Lib/test/test_ipaddress.py
@@ -7,6 +7,7 @@
import unittest
import re
import contextlib
+import functools
import operator
import pickle
import ipaddress
@@ -552,6 +553,20 @@ class FactoryFunctionErrors(BaseTestCase):
self.assertFactoryError(ipaddress.ip_network, "network")
+@functools.total_ordering
+class LargestObject:
+ def __eq__(self, other):
+ return isinstance(other, LargestObject)
+ def __lt__(self, other):
+ return False
+
+@functools.total_ordering
+class SmallestObject:
+ def __eq__(self, other):
+ return isinstance(other, SmallestObject)
+ def __gt__(self, other):
+ return False
+
class ComparisonTests(unittest.TestCase):
v4addr = ipaddress.IPv4Address(1)
@@ -605,6 +620,28 @@ class ComparisonTests(unittest.TestCase):
self.assertRaises(TypeError, lambda: lhs <= rhs)
self.assertRaises(TypeError, lambda: lhs >= rhs)
+ def test_foreign_type_ordering(self):
+ other = object()
+ smallest = SmallestObject()
+ largest = LargestObject()
+ for obj in self.objects:
+ with self.assertRaises(TypeError):
+ obj < other
+ with self.assertRaises(TypeError):
+ obj > other
+ with self.assertRaises(TypeError):
+ obj <= other
+ with self.assertRaises(TypeError):
+ obj >= other
+ self.assertTrue(obj < largest)
+ self.assertFalse(obj > largest)
+ self.assertTrue(obj <= largest)
+ self.assertFalse(obj >= largest)
+ self.assertFalse(obj < smallest)
+ self.assertTrue(obj > smallest)
+ self.assertFalse(obj <= smallest)
+ self.assertTrue(obj >= smallest)
+
def test_mixed_type_key(self):
# with get_mixed_type_key, you can sort addresses and network.
v4_ordered = [self.v4addr, self.v4net, self.v4intf]
@@ -625,7 +662,7 @@ class ComparisonTests(unittest.TestCase):
v4addr = ipaddress.ip_address('1.1.1.1')
v4net = ipaddress.ip_network('1.1.1.1')
v6addr = ipaddress.ip_address('::1')
- v6net = ipaddress.ip_address('::1')
+ v6net = ipaddress.ip_network('::1')
self.assertRaises(TypeError, v4addr.__lt__, v6addr)
self.assertRaises(TypeError, v4addr.__gt__, v6addr)
@@ -1383,10 +1420,10 @@ class IpaddrUnitTest(unittest.TestCase):
unsorted = [ip4, ip1, ip3, ip2]
unsorted.sort()
self.assertEqual(sorted, unsorted)
- self.assertRaises(TypeError, ip1.__lt__,
- ipaddress.ip_address('10.10.10.0'))
- self.assertRaises(TypeError, ip2.__lt__,
- ipaddress.ip_address('10.10.10.0'))
+ self.assertIs(ip1.__lt__(ipaddress.ip_address('10.10.10.0')),
+ NotImplemented)
+ self.assertIs(ip2.__lt__(ipaddress.ip_address('10.10.10.0')),
+ NotImplemented)
# <=, >=
self.assertTrue(ipaddress.ip_network('1.1.1.1') <=