diff options
author | Nick Coghlan <ncoghlan@gmail.com> | 2012-07-07 11:43:30 (GMT) |
---|---|---|
committer | Nick Coghlan <ncoghlan@gmail.com> | 2012-07-07 11:43:30 (GMT) |
commit | 7319f69f490f6f787297904dff901500576c886e (patch) | |
tree | dc1fb7f257140f5588d19995ce6a676dcd2e5dd6 /Lib | |
parent | 79d79a0f290a528db4eea2afb41828ccee21772a (diff) | |
download | cpython-7319f69f490f6f787297904dff901500576c886e.zip cpython-7319f69f490f6f787297904dff901500576c886e.tar.gz cpython-7319f69f490f6f787297904dff901500576c886e.tar.bz2 |
Issue 14814: Make the ipaddress code easier to follow by using newer language features (patch by Serhiy Storchaka)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ipaddress.py | 158 |
1 files changed, 67 insertions, 91 deletions
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index 6daa955..bd79e2a 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -214,8 +214,10 @@ def _count_righthand_zero_bits(number, bits): if number == 0: return bits for i in range(bits): - if (number >> i) % 2: + if (number >> i) & 1: return i + # All bits of interest were zero, even if there are more in the number + return bits def summarize_address_range(first, last): @@ -263,20 +265,13 @@ def summarize_address_range(first, last): first_int = first._ip last_int = last._ip while first_int <= last_int: - nbits = _count_righthand_zero_bits(first_int, ip_bits) - current = None - while nbits >= 0: - addend = 2**nbits - 1 - current = first_int + addend - nbits -= 1 - if current <= last_int: - break - prefix = _get_prefix_length(first_int, current, ip_bits) - net = ip('%s/%d' % (first, prefix)) + nbits = min(_count_righthand_zero_bits(first_int, ip_bits), + (last_int - first_int + 1).bit_length() - 1) + net = ip('%s/%d' % (first, ip_bits - nbits)) yield net - if current == ip._ALL_ONES: + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: break - first_int = current + 1 first = first.__class__(first_int) @@ -304,26 +299,28 @@ def _collapse_addresses_recursive(addresses): passed. """ - ret_array = [] - optimized = False - - for cur_addr in addresses: - if not ret_array: - ret_array.append(cur_addr) - continue - if (cur_addr.network_address >= ret_array[-1].network_address and - cur_addr.broadcast_address <= ret_array[-1].broadcast_address): - optimized = True - elif cur_addr == list(ret_array[-1].supernet().subnets())[1]: - ret_array.append(ret_array.pop().supernet()) - optimized = True - else: - ret_array.append(cur_addr) - - if optimized: - return _collapse_addresses_recursive(ret_array) + while True: + last_addr = None + ret_array = [] + optimized = False + + for cur_addr in addresses: + if not ret_array: + last_addr = cur_addr + ret_array.append(cur_addr) + elif (cur_addr.network_address >= last_addr.network_address and + cur_addr.broadcast_address <= last_addr.broadcast_address): + optimized = True + elif cur_addr == list(last_addr.supernet().subnets())[1]: + ret_array[-1] = last_addr = last_addr.supernet() + optimized = True + else: + last_addr = cur_addr + ret_array.append(cur_addr) - return ret_array + addresses = ret_array + if not optimized: + return addresses def collapse_addresses(addresses): @@ -452,13 +449,7 @@ class _IPAddressBase: An integer, the prefix length. """ - while mask: - if ip_int & 1 == 1: - break - ip_int >>= 1 - mask -= 1 - - return mask + return mask - _count_righthand_zero_bits(ip_int, mask) def _ip_string_from_prefix(self, prefixlen=None): """Turn a prefix length into a dotted decimal string. @@ -597,18 +588,16 @@ class _BaseNetwork(_IPAddressBase): or broadcast addresses. """ - cur = int(self.network_address) + 1 - bcast = int(self.broadcast_address) - 1 - while cur <= bcast: - cur += 1 - yield self._address_class(cur - 1) + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network + 1, broadcast): + yield self._address_class(x) def __iter__(self): - cur = int(self.network_address) - bcast = int(self.broadcast_address) - while cur <= bcast: - cur += 1 - yield self._address_class(cur - 1) + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network, broadcast + 1): + yield self._address_class(x) def __getitem__(self, n): network = int(self.network_address) @@ -998,7 +987,7 @@ class _BaseV4: _DECIMAL_DIGITS = frozenset('0123456789') # the valid octets for host and netmasks. only useful for IPv4. - _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0)) + _valid_mask_octets = frozenset((255, 254, 252, 248, 240, 224, 192, 128, 0)) def __init__(self, address): self._version = 4 @@ -1027,13 +1016,10 @@ class _BaseV4: if len(octets) != 4: raise AddressValueError("Expected 4 octets in %r" % ip_str) - packed_ip = 0 - for oc in octets: - try: - packed_ip = (packed_ip << 8) | self._parse_octet(oc) - except ValueError as exc: - raise AddressValueError("%s in %r" % (exc, ip_str)) from None - return packed_ip + try: + return int.from_bytes(map(self._parse_octet, octets), 'big') + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None def _parse_octet(self, octet_str): """Convert a decimal octet into an integer. @@ -1075,11 +1061,7 @@ class _BaseV4: The IP address as a string in dotted decimal notation. """ - octets = [] - for _ in range(4): - octets.insert(0, str(ip_int & 0xFF)) - ip_int >>= 8 - return '.'.join(octets) + return '.'.join(map(str, ip_int.to_bytes(4, 'big'))) def _is_valid_netmask(self, netmask): """Verify that the netmask is valid. @@ -1095,17 +1077,16 @@ class _BaseV4: """ mask = netmask.split('.') if len(mask) == 4: - for x in mask: - try: - if int(x) in self._valid_mask_octets: - continue - except ValueError: - pass + try: + for x in mask: + if int(x) not in self._valid_mask_octets: + return False + except ValueError: # Found something that isn't an integer or isn't valid return False - if [y for idx, y in enumerate(mask) if idx > 0 and - y > mask[idx - 1]]: - return False + for idx, y in enumerate(mask): + if idx > 0 and y > mask[idx - 1]: + return False return True try: netmask = int(netmask) @@ -1125,7 +1106,7 @@ class _BaseV4: """ bits = ip_str.split('.') try: - parts = [int(x) for x in bits if int(x) in self._valid_mask_octets] + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] except ValueError: return False if len(parts) != len(bits): @@ -1526,14 +1507,14 @@ class _BaseV6: # Disregarding the endpoints, find '::' with nothing in between. # This indicates that a run of zeroes has been skipped. - try: - skip_index, = ( - [i for i in range(1, len(parts) - 1) if not parts[i]] or - [None]) - except ValueError: - # Can't have more than one '::' - msg = "At most one '::' permitted in %r" % ip_str - raise AddressValueError(msg) from None + skip_index = None + for i in range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i # parts_hi is the number of parts to copy from above/before the '::' # parts_lo is the number of parts to copy from below/after the '::' @@ -1680,9 +1661,7 @@ class _BaseV6: raise ValueError('IPv6 address is too large') hex_str = '%032x' % ip_int - hextets = [] - for x in range(0, 32, 4): - hextets.append('%x' % int(hex_str[x:x+4], 16)) + hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)] hextets = self._compress_hextets(hextets) return ':'.join(hextets) @@ -1705,11 +1684,8 @@ class _BaseV6: ip_str = str(self) ip_int = self._ip_int_from_string(ip_str) - parts = [] - for i in range(self._HEXTET_COUNT): - parts.append('%04x' % (ip_int & 0xFFFF)) - ip_int >>= 16 - parts.reverse() + hex_str = '%032x' % ip_int + parts = [hex_str[x:x+4] for x in range(0, 32, 4)] if isinstance(self, (_BaseNetwork, IPv6Interface)): return '%s/%d' % (':'.join(parts), self.prefixlen) return ':'.join(parts) @@ -1756,9 +1732,9 @@ class _BaseV6: IPv6Network('FE00::/9')] if isinstance(self, _BaseAddress): - return len([x for x in reserved_networks if self in x]) > 0 - return len([x for x in reserved_networks if self.network_address in x - and self.broadcast_address in x]) > 0 + return any(self in x for x in reserved_networks) + return any(self.network_address in x and self.broadcast_address in x + for x in reserved_networks) @property def is_link_local(self): |