diff options
-rw-r--r-- | Lib/ipaddress.py | 122 | ||||
-rw-r--r-- | Lib/test/test_ipaddress.py | 20 |
2 files changed, 48 insertions, 94 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index cc6760b..750e7a1 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -36,34 +36,22 @@ class NetmaskValueError(ValueError): """A Value Error related to the netmask.""" -def ip_address(address, version=None): +def ip_address(address): """Take an IP string/int and return an object of the correct type. Args: address: A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied; integers less than 2**32 will be considered to be IPv4 by default. - version: An integer, 4 or 6. If set, don't try to automatically - determine what the IP address type is. Important for things - like ip_address(1), which could be IPv4, '192.0.2.1', or IPv6, - '2001:db8::1'. Returns: An IPv4Address or IPv6Address object. Raises: ValueError: if the *address* passed isn't either a v4 or a v6 - address, or if the version is not None, 4, or 6. + address """ - if version is not None: - if version == 4: - return IPv4Address(address) - elif version == 6: - return IPv6Address(address) - else: - raise ValueError() - try: return IPv4Address(address) except (AddressValueError, NetmaskValueError): @@ -78,35 +66,22 @@ def ip_address(address, version=None): address) -def ip_network(address, version=None, strict=True): +def ip_network(address, strict=True): """Take an IP string/int and return an object of the correct type. Args: address: A string or integer, the IP network. Either IPv4 or IPv6 networks may be supplied; integers less than 2**32 will be considered to be IPv4 by default. - version: An integer, 4 or 6. If set, don't try to automatically - determine what the IP address type is. Important for things - like ip_network(1), which could be IPv4, '192.0.2.1/32', or IPv6, - '2001:db8::1/128'. Returns: An IPv4Network or IPv6Network object. Raises: ValueError: if the string passed isn't either a v4 or a v6 - address. Or if the network has host bits set. Or if the version - is not None, 4, or 6. + address. Or if the network has host bits set. """ - if version is not None: - if version == 4: - return IPv4Network(address, strict) - elif version == 6: - return IPv6Network(address, strict) - else: - raise ValueError() - try: return IPv4Network(address, strict) except (AddressValueError, NetmaskValueError): @@ -121,24 +96,20 @@ def ip_network(address, version=None, strict=True): address) -def ip_interface(address, version=None): +def ip_interface(address): """Take an IP string/int and return an object of the correct type. Args: address: A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied; integers less than 2**32 will be considered to be IPv4 by default. - version: An integer, 4 or 6. If set, don't try to automatically - determine what the IP address type is. Important for things - like ip_interface(1), which could be IPv4, '192.0.2.1/32', or IPv6, - '2001:db8::1/128'. Returns: An IPv4Interface or IPv6Interface object. Raises: ValueError: if the string passed isn't either a v4 or a v6 - address. Or if the version is not None, 4, or 6. + address. Notes: The IPv?Interface classes describe an Address on a particular @@ -146,14 +117,6 @@ def ip_interface(address, version=None): and Network classes. """ - if version is not None: - if version == 4: - return IPv4Interface(address) - elif version == 6: - return IPv6Interface(address) - else: - raise ValueError() - try: return IPv4Interface(address) except (AddressValueError, NetmaskValueError): @@ -281,7 +244,7 @@ def summarize_address_range(first, last): If the first and last objects are not the same version. ValueError: If the last object is not greater than the first. - If the version is not 4 or 6. + If the version of the first address is not 4 or 6. """ if (not (isinstance(first, _BaseAddress) and @@ -318,7 +281,7 @@ def summarize_address_range(first, last): if current == ip._ALL_ONES: break first_int = current + 1 - first = ip_address(first_int, version=first._version) + first = first.__class__(first_int) def _collapse_addresses_recursive(addresses): @@ -586,12 +549,12 @@ class _BaseAddress(_IPAddressBase): def __add__(self, other): if not isinstance(other, int): return NotImplemented - return ip_address(int(self) + other, version=self._version) + return self.__class__(int(self) + other) def __sub__(self, other): if not isinstance(other, int): return NotImplemented - return ip_address(int(self) - other, version=self._version) + return self.__class__(int(self) - other) def __repr__(self): return '%s(%r)' % (self.__class__.__name__, str(self)) @@ -612,13 +575,12 @@ class _BaseAddress(_IPAddressBase): class _BaseNetwork(_IPAddressBase): - """A generic IP object. + """A generic IP network object. This IP class contains the version independent methods which are used by networks. """ - def __init__(self, address): self._cache = {} @@ -642,14 +604,14 @@ class _BaseNetwork(_IPAddressBase): bcast = int(self.broadcast_address) - 1 while cur <= bcast: cur += 1 - yield ip_address(cur - 1, version=self._version) + yield self._address_class(cur - 1) def __iter__(self): cur = int(self.network_address) bcast = int(self.broadcast_address) while cur <= bcast: cur += 1 - yield ip_address(cur - 1, version=self._version) + yield self._address_class(cur - 1) def __getitem__(self, n): network = int(self.network_address) @@ -657,12 +619,12 @@ class _BaseNetwork(_IPAddressBase): if n >= 0: if network + n > broadcast: raise IndexError - return ip_address(network + n, version=self._version) + return self._address_class(network + n) else: n += 1 if broadcast + n < network: raise IndexError - return ip_address(broadcast + n, version=self._version) + return self._address_class(broadcast + n) def __lt__(self, other): if self._version != other._version: @@ -746,8 +708,8 @@ class _BaseNetwork(_IPAddressBase): def broadcast_address(self): x = self._cache.get('broadcast_address') if x is None: - x = ip_address(int(self.network_address) | int(self.hostmask), - version=self._version) + x = self._address_class(int(self.network_address) | + int(self.hostmask)) self._cache['broadcast_address'] = x return x @@ -755,15 +717,15 @@ class _BaseNetwork(_IPAddressBase): def hostmask(self): x = self._cache.get('hostmask') if x is None: - x = ip_address(int(self.netmask) ^ self._ALL_ONES, - version=self._version) + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) self._cache['hostmask'] = x return x @property def network(self): - return ip_network('%s/%d' % (str(self.network_address), - self.prefixlen)) + # XXX (ncoghlan): This is redundant now and will likely be removed + return self.__class__('%s/%d' % (str(self.network_address), + self.prefixlen)) @property def with_prefixlen(self): @@ -787,6 +749,10 @@ class _BaseNetwork(_IPAddressBase): raise NotImplementedError('BaseNet has no version') @property + def _address_class(self): + raise NotImplementedError('BaseNet has no associated address class') + + @property def prefixlen(self): return self._prefixlen @@ -840,9 +806,8 @@ class _BaseNetwork(_IPAddressBase): raise StopIteration # Make sure we're comparing the network of other. - other = ip_network('%s/%s' % (str(other.network_address), - str(other.prefixlen)), - version=other._version) + other = other.__class__('%s/%s' % (str(other.network_address), + str(other.prefixlen))) s1, s2 = self.subnets() while s1 != other and s2 != other: @@ -973,9 +938,9 @@ class _BaseNetwork(_IPAddressBase): 'prefix length diff %d is invalid for netblock %s' % ( new_prefixlen, str(self))) - first = ip_network('%s/%s' % (str(self.network_address), - str(self._prefixlen + prefixlen_diff)), - version=self._version) + first = self.__class__('%s/%s' % + (str(self.network_address), + str(self._prefixlen + prefixlen_diff))) yield first current = first @@ -983,16 +948,17 @@ class _BaseNetwork(_IPAddressBase): broadcast = current.broadcast_address if broadcast == self.broadcast_address: return - new_addr = ip_address(int(broadcast) + 1, version=self._version) - current = ip_network('%s/%s' % (str(new_addr), str(new_prefixlen)), - version=self._version) + new_addr = self._address_class(int(broadcast) + 1) + current = self.__class__('%s/%s' % (str(new_addr), + str(new_prefixlen))) yield current def masked(self): """Return the network object with the host bits masked out.""" - return ip_network('%s/%d' % (self.network_address, self._prefixlen), - version=self._version) + # XXX (ncoghlan): This is redundant now and will likely be removed + return self.__class__('%s/%d' % (self.network_address, + self._prefixlen)) def supernet(self, prefixlen_diff=1, new_prefix=None): """The supernet containing the current network. @@ -1030,11 +996,10 @@ class _BaseNetwork(_IPAddressBase): 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % (self.prefixlen, prefixlen_diff)) # TODO (pmoody): optimize this. - t = ip_network('%s/%d' % (str(self.network_address), - self.prefixlen - prefixlen_diff), - version=self._version, strict=False) - return ip_network('%s/%d' % (str(t.network_address), t.prefixlen), - version=t._version) + t = self.__class__('%s/%d' % (str(self.network_address), + self.prefixlen - prefixlen_diff), + strict=False) + return t.__class__('%s/%d' % (str(t.network_address), t.prefixlen)) class _BaseV4(object): @@ -1391,6 +1356,9 @@ class IPv4Network(_BaseV4, _BaseNetwork): .prefixlen: 27 """ + # Class to use when creating address objects + # TODO (ncoghlan): Investigate using IPv4Interface instead + _address_class = IPv4Address # the valid octets for host and netmasks. only useful for IPv4. _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0)) @@ -2071,6 +2039,10 @@ class IPv6Network(_BaseV6, _BaseNetwork): """ + # Class to use when creating address objects + # TODO (ncoghlan): Investigate using IPv6Interface instead + _address_class = IPv6Address + def __init__(self, address, strict=True): """Instantiate a new IPv6 Network object. diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index 6bf5174..bf5286b 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -780,12 +780,6 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(self.ipv4_address.version, 4) self.assertEqual(self.ipv6_address.version, 6) - with self.assertRaises(ValueError): - ipaddress.ip_address('1', version=[]) - - with self.assertRaises(ValueError): - ipaddress.ip_address('1', version=5) - def testMaxPrefixLength(self): self.assertEqual(self.ipv4_interface.max_prefixlen, 32) self.assertEqual(self.ipv6_interface.max_prefixlen, 128) @@ -1052,12 +1046,7 @@ class IpaddrUnitTest(unittest.TestCase): def testForceVersion(self): self.assertEqual(ipaddress.ip_network(1).version, 4) - self.assertEqual(ipaddress.ip_network(1, version=6).version, 6) - - with self.assertRaises(ValueError): - ipaddress.ip_network(1, version='l') - with self.assertRaises(ValueError): - ipaddress.ip_network(1, version=3) + self.assertEqual(ipaddress.IPv6Network(1).version, 6) def testWithStar(self): self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24") @@ -1148,13 +1137,6 @@ class IpaddrUnitTest(unittest.TestCase): sixtofouraddr.sixtofour) self.assertFalse(bad_addr.sixtofour) - def testIpInterfaceVersion(self): - with self.assertRaises(ValueError): - ipaddress.ip_interface(1, version=123) - - with self.assertRaises(ValueError): - ipaddress.ip_interface(1, version='') - if __name__ == '__main__': unittest.main() |