summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorNick Coghlan <ncoghlan@gmail.com>2012-07-07 11:43:30 (GMT)
committerNick Coghlan <ncoghlan@gmail.com>2012-07-07 11:43:30 (GMT)
commit7319f69f490f6f787297904dff901500576c886e (patch)
treedc1fb7f257140f5588d19995ce6a676dcd2e5dd6 /Lib
parent79d79a0f290a528db4eea2afb41828ccee21772a (diff)
downloadcpython-7319f69f490f6f787297904dff901500576c886e.zip
cpython-7319f69f490f6f787297904dff901500576c886e.tar.gz
cpython-7319f69f490f6f787297904dff901500576c886e.tar.bz2
Issue 14814: Make the ipaddress code easier to follow by using newer language features (patch by Serhiy Storchaka)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ipaddress.py158
1 files changed, 67 insertions, 91 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py
index 6daa955..bd79e2a 100644
--- a/Lib/ipaddress.py
+++ b/Lib/ipaddress.py
@@ -214,8 +214,10 @@ def _count_righthand_zero_bits(number, bits):
if number == 0:
return bits
for i in range(bits):
- if (number >> i) % 2:
+ if (number >> i) & 1:
return i
+ # All bits of interest were zero, even if there are more in the number
+ return bits
def summarize_address_range(first, last):
@@ -263,20 +265,13 @@ def summarize_address_range(first, last):
first_int = first._ip
last_int = last._ip
while first_int <= last_int:
- nbits = _count_righthand_zero_bits(first_int, ip_bits)
- current = None
- while nbits >= 0:
- addend = 2**nbits - 1
- current = first_int + addend
- nbits -= 1
- if current <= last_int:
- break
- prefix = _get_prefix_length(first_int, current, ip_bits)
- net = ip('%s/%d' % (first, prefix))
+ nbits = min(_count_righthand_zero_bits(first_int, ip_bits),
+ (last_int - first_int + 1).bit_length() - 1)
+ net = ip('%s/%d' % (first, ip_bits - nbits))
yield net
- if current == ip._ALL_ONES:
+ first_int += 1 << nbits
+ if first_int - 1 == ip._ALL_ONES:
break
- first_int = current + 1
first = first.__class__(first_int)
@@ -304,26 +299,28 @@ def _collapse_addresses_recursive(addresses):
passed.
"""
- ret_array = []
- optimized = False
-
- for cur_addr in addresses:
- if not ret_array:
- ret_array.append(cur_addr)
- continue
- if (cur_addr.network_address >= ret_array[-1].network_address and
- cur_addr.broadcast_address <= ret_array[-1].broadcast_address):
- optimized = True
- elif cur_addr == list(ret_array[-1].supernet().subnets())[1]:
- ret_array.append(ret_array.pop().supernet())
- optimized = True
- else:
- ret_array.append(cur_addr)
-
- if optimized:
- return _collapse_addresses_recursive(ret_array)
+ while True:
+ last_addr = None
+ ret_array = []
+ optimized = False
+
+ for cur_addr in addresses:
+ if not ret_array:
+ last_addr = cur_addr
+ ret_array.append(cur_addr)
+ elif (cur_addr.network_address >= last_addr.network_address and
+ cur_addr.broadcast_address <= last_addr.broadcast_address):
+ optimized = True
+ elif cur_addr == list(last_addr.supernet().subnets())[1]:
+ ret_array[-1] = last_addr = last_addr.supernet()
+ optimized = True
+ else:
+ last_addr = cur_addr
+ ret_array.append(cur_addr)
- return ret_array
+ addresses = ret_array
+ if not optimized:
+ return addresses
def collapse_addresses(addresses):
@@ -452,13 +449,7 @@ class _IPAddressBase:
An integer, the prefix length.
"""
- while mask:
- if ip_int & 1 == 1:
- break
- ip_int >>= 1
- mask -= 1
-
- return mask
+ return mask - _count_righthand_zero_bits(ip_int, mask)
def _ip_string_from_prefix(self, prefixlen=None):
"""Turn a prefix length into a dotted decimal string.
@@ -597,18 +588,16 @@ class _BaseNetwork(_IPAddressBase):
or broadcast addresses.
"""
- cur = int(self.network_address) + 1
- bcast = int(self.broadcast_address) - 1
- while cur <= bcast:
- cur += 1
- yield self._address_class(cur - 1)
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in range(network + 1, broadcast):
+ yield self._address_class(x)
def __iter__(self):
- cur = int(self.network_address)
- bcast = int(self.broadcast_address)
- while cur <= bcast:
- cur += 1
- yield self._address_class(cur - 1)
+ network = int(self.network_address)
+ broadcast = int(self.broadcast_address)
+ for x in range(network, broadcast + 1):
+ yield self._address_class(x)
def __getitem__(self, n):
network = int(self.network_address)
@@ -998,7 +987,7 @@ class _BaseV4:
_DECIMAL_DIGITS = frozenset('0123456789')
# the valid octets for host and netmasks. only useful for IPv4.
- _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
+ _valid_mask_octets = frozenset((255, 254, 252, 248, 240, 224, 192, 128, 0))
def __init__(self, address):
self._version = 4
@@ -1027,13 +1016,10 @@ class _BaseV4:
if len(octets) != 4:
raise AddressValueError("Expected 4 octets in %r" % ip_str)
- packed_ip = 0
- for oc in octets:
- try:
- packed_ip = (packed_ip << 8) | self._parse_octet(oc)
- except ValueError as exc:
- raise AddressValueError("%s in %r" % (exc, ip_str)) from None
- return packed_ip
+ try:
+ return int.from_bytes(map(self._parse_octet, octets), 'big')
+ except ValueError as exc:
+ raise AddressValueError("%s in %r" % (exc, ip_str)) from None
def _parse_octet(self, octet_str):
"""Convert a decimal octet into an integer.
@@ -1075,11 +1061,7 @@ class _BaseV4:
The IP address as a string in dotted decimal notation.
"""
- octets = []
- for _ in range(4):
- octets.insert(0, str(ip_int & 0xFF))
- ip_int >>= 8
- return '.'.join(octets)
+ return '.'.join(map(str, ip_int.to_bytes(4, 'big')))
def _is_valid_netmask(self, netmask):
"""Verify that the netmask is valid.
@@ -1095,17 +1077,16 @@ class _BaseV4:
"""
mask = netmask.split('.')
if len(mask) == 4:
- for x in mask:
- try:
- if int(x) in self._valid_mask_octets:
- continue
- except ValueError:
- pass
+ try:
+ for x in mask:
+ if int(x) not in self._valid_mask_octets:
+ return False
+ except ValueError:
# Found something that isn't an integer or isn't valid
return False
- if [y for idx, y in enumerate(mask) if idx > 0 and
- y > mask[idx - 1]]:
- return False
+ for idx, y in enumerate(mask):
+ if idx > 0 and y > mask[idx - 1]:
+ return False
return True
try:
netmask = int(netmask)
@@ -1125,7 +1106,7 @@ class _BaseV4:
"""
bits = ip_str.split('.')
try:
- parts = [int(x) for x in bits if int(x) in self._valid_mask_octets]
+ parts = [x for x in map(int, bits) if x in self._valid_mask_octets]
except ValueError:
return False
if len(parts) != len(bits):
@@ -1526,14 +1507,14 @@ class _BaseV6:
# Disregarding the endpoints, find '::' with nothing in between.
# This indicates that a run of zeroes has been skipped.
- try:
- skip_index, = (
- [i for i in range(1, len(parts) - 1) if not parts[i]] or
- [None])
- except ValueError:
- # Can't have more than one '::'
- msg = "At most one '::' permitted in %r" % ip_str
- raise AddressValueError(msg) from None
+ skip_index = None
+ for i in range(1, len(parts) - 1):
+ if not parts[i]:
+ if skip_index is not None:
+ # Can't have more than one '::'
+ msg = "At most one '::' permitted in %r" % ip_str
+ raise AddressValueError(msg)
+ skip_index = i
# parts_hi is the number of parts to copy from above/before the '::'
# parts_lo is the number of parts to copy from below/after the '::'
@@ -1680,9 +1661,7 @@ class _BaseV6:
raise ValueError('IPv6 address is too large')
hex_str = '%032x' % ip_int
- hextets = []
- for x in range(0, 32, 4):
- hextets.append('%x' % int(hex_str[x:x+4], 16))
+ hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)]
hextets = self._compress_hextets(hextets)
return ':'.join(hextets)
@@ -1705,11 +1684,8 @@ class _BaseV6:
ip_str = str(self)
ip_int = self._ip_int_from_string(ip_str)
- parts = []
- for i in range(self._HEXTET_COUNT):
- parts.append('%04x' % (ip_int & 0xFFFF))
- ip_int >>= 16
- parts.reverse()
+ hex_str = '%032x' % ip_int
+ parts = [hex_str[x:x+4] for x in range(0, 32, 4)]
if isinstance(self, (_BaseNetwork, IPv6Interface)):
return '%s/%d' % (':'.join(parts), self.prefixlen)
return ':'.join(parts)
@@ -1756,9 +1732,9 @@ class _BaseV6:
IPv6Network('FE00::/9')]
if isinstance(self, _BaseAddress):
- return len([x for x in reserved_networks if self in x]) > 0
- return len([x for x in reserved_networks if self.network_address in x
- and self.broadcast_address in x]) > 0
+ return any(self in x for x in reserved_networks)
+ return any(self.network_address in x and self.broadcast_address in x
+ for x in reserved_networks)
@property
def is_link_local(self):