diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2020-06-11 06:17:58 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-11 06:17:58 (GMT) |
commit | 31d17798d6567036d3ac2771555a919b3628962f (patch) | |
tree | 0a418e6a07c3655aa68fdf39c380cc6e72c3aaa8 /Lib/collections | |
parent | 896f4cf63f9ab93e30572d879a5719d5aa2499fb (diff) | |
download | cpython-31d17798d6567036d3ac2771555a919b3628962f.zip cpython-31d17798d6567036d3ac2771555a919b3628962f.tar.gz cpython-31d17798d6567036d3ac2771555a919b3628962f.tar.bz2 |
Collections module reformatting and minor code refactoring (GH-20772)
Diffstat (limited to 'Lib/collections')
-rw-r--r-- | Lib/collections/__init__.py | 362 |
1 files changed, 271 insertions, 91 deletions
diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 6a06cc6..42d0ec7 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -14,17 +14,30 @@ list, set, and tuple. ''' -__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList', - 'UserString', 'Counter', 'OrderedDict', 'ChainMap'] +__all__ = [ + 'ChainMap', + 'Counter', + 'OrderedDict', + 'UserDict', + 'UserList', + 'UserString', + 'defaultdict', + 'deque', + 'namedtuple', +] import _collections_abc -from operator import itemgetter as _itemgetter, eq as _eq -from keyword import iskeyword as _iskeyword -import sys as _sys import heapq as _heapq -from _weakref import proxy as _proxy -from itertools import repeat as _repeat, chain as _chain, starmap as _starmap +import sys as _sys + +from itertools import chain as _chain +from itertools import repeat as _repeat +from itertools import starmap as _starmap +from keyword import iskeyword as _iskeyword +from operator import eq as _eq +from operator import itemgetter as _itemgetter from reprlib import recursive_repr as _recursive_repr +from _weakref import proxy as _proxy try: from _collections import deque @@ -54,6 +67,7 @@ def __getattr__(name): return obj raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + ################################################################################ ### OrderedDict ################################################################################ @@ -408,10 +422,13 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non # Create all the named tuple methods to be added to the class namespace - s = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' - namespace = {'_tuple_new': tuple_new, '__builtins__': None, - '__name__': f'namedtuple_{typename}'} - __new__ = eval(s, namespace) + namespace = { + '_tuple_new': tuple_new, + '__builtins__': None, + '__name__': f'namedtuple_{typename}', + } + code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' + __new__ = eval(code, namespace) __new__.__name__ = '__new__' __new__.__doc__ = f'Create new instance of {typename}({arg_list})' if defaults is not None: @@ -449,8 +466,14 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non return _tuple(self) # Modify function metadata to help with introspection and debugging - for method in (__new__, _make.__func__, _replace, - __repr__, _asdict, __getnewargs__): + for method in ( + __new__, + _make.__func__, + _replace, + __repr__, + _asdict, + __getnewargs__, + ): method.__qualname__ = f'{typename}.{method.__name__}' # Build-up the class namespace dictionary @@ -566,7 +589,7 @@ class Counter(dict): >>> c = Counter(a=4, b=2) # a new counter from keyword args ''' - super(Counter, self).__init__() + super().__init__() self.update(iterable, **kwds) def __missing__(self, key): @@ -650,7 +673,8 @@ class Counter(dict): for elem, count in iterable.items(): self[elem] = count + self_get(elem, 0) else: - super(Counter, self).update(iterable) # fast path when counter is empty + # fast path when counter is empty + super().update(iterable) else: _count_elements(self, iterable) if kwds: @@ -733,13 +757,14 @@ class Counter(dict): def __repr__(self): if not self: - return '%s()' % self.__class__.__name__ + return f'{self.__class__.__name__}()' try: - items = ', '.join(map('%r: %r'.__mod__, self.most_common())) - return '%s({%s})' % (self.__class__.__name__, items) + # dict() preserves the ordering returned by most_common() + d = dict(self.most_common()) except TypeError: # handle case where values are not orderable - return '{0}({1!r})'.format(self.__class__.__name__, dict(self)) + d = dict(self) + return f'{self.__class__.__name__}({d!r})' # Multiset-style mathematical operations discussed in: # Knuth TAOCP Volume II section 4.6.3 exercise 19 @@ -1018,7 +1043,7 @@ class ChainMap(_collections_abc.MutableMapping): try: del self.maps[0][key] except KeyError: - raise KeyError('Key not found in the first mapping: {!r}'.format(key)) + raise KeyError(f'Key not found in the first mapping: {key!r}') def popitem(self): 'Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty.' @@ -1032,30 +1057,30 @@ class ChainMap(_collections_abc.MutableMapping): try: return self.maps[0].pop(key, *args) except KeyError: - raise KeyError('Key not found in the first mapping: {!r}'.format(key)) + raise KeyError(f'Key not found in the first mapping: {key!r}') def clear(self): 'Clear maps[0], leaving maps[1:] intact.' self.maps[0].clear() def __ior__(self, other): - self.maps[0] |= other + self.maps[0].update(other) return self def __or__(self, other): - if isinstance(other, _collections_abc.Mapping): - m = self.maps[0].copy() - m.update(other) - return self.__class__(m, *self.maps[1:]) - return NotImplemented + if not isinstance(other, _collections_abc.Mapping): + return NotImplemented + m = self.copy() + m.maps[0].update(other) + return m def __ror__(self, other): - if isinstance(other, _collections_abc.Mapping): - m = dict(other) - for child in reversed(self.maps): - m.update(child) - return self.__class__(m) - return NotImplemented + if not isinstance(other, _collections_abc.Mapping): + return NotImplemented + m = dict(other) + for child in reversed(self.maps): + m.update(child) + return self.__class__(m) ################################################################################ @@ -1072,15 +1097,22 @@ class UserDict(_collections_abc.MutableMapping): if kwargs: self.update(kwargs) - def __len__(self): return len(self.data) + def __len__(self): + return len(self.data) + def __getitem__(self, key): if key in self.data: return self.data[key] if hasattr(self.__class__, "__missing__"): return self.__class__.__missing__(self, key) raise KeyError(key) - def __setitem__(self, key, item): self.data[key] = item - def __delitem__(self, key): del self.data[key] + + def __setitem__(self, key, item): + self.data[key] = item + + def __delitem__(self, key): + del self.data[key] + def __iter__(self): return iter(self.data) @@ -1089,7 +1121,8 @@ class UserDict(_collections_abc.MutableMapping): return key in self.data # Now, add the methods in dicts but not in MutableMapping - def __repr__(self): return repr(self.data) + def __repr__(self): + return repr(self.data) def __or__(self, other): if isinstance(other, UserDict): @@ -1097,12 +1130,14 @@ class UserDict(_collections_abc.MutableMapping): if isinstance(other, dict): return self.__class__(self.data | other) return NotImplemented + def __ror__(self, other): if isinstance(other, UserDict): return self.__class__(other.data | self.data) if isinstance(other, dict): return self.__class__(other | self.data) return NotImplemented + def __ior__(self, other): if isinstance(other, UserDict): self.data |= other.data @@ -1138,13 +1173,13 @@ class UserDict(_collections_abc.MutableMapping): return d - ################################################################################ ### UserList ################################################################################ class UserList(_collections_abc.MutableSequence): """A more or less complete user-defined wrapper around list objects.""" + def __init__(self, initlist=None): self.data = [] if initlist is not None: @@ -1155,35 +1190,60 @@ class UserList(_collections_abc.MutableSequence): self.data[:] = initlist.data[:] else: self.data = list(initlist) - def __repr__(self): return repr(self.data) - def __lt__(self, other): return self.data < self.__cast(other) - def __le__(self, other): return self.data <= self.__cast(other) - def __eq__(self, other): return self.data == self.__cast(other) - def __gt__(self, other): return self.data > self.__cast(other) - def __ge__(self, other): return self.data >= self.__cast(other) + + def __repr__(self): + return repr(self.data) + + def __lt__(self, other): + return self.data < self.__cast(other) + + def __le__(self, other): + return self.data <= self.__cast(other) + + def __eq__(self, other): + return self.data == self.__cast(other) + + def __gt__(self, other): + return self.data > self.__cast(other) + + def __ge__(self, other): + return self.data >= self.__cast(other) + def __cast(self, other): return other.data if isinstance(other, UserList) else other - def __contains__(self, item): return item in self.data - def __len__(self): return len(self.data) + + def __contains__(self, item): + return item in self.data + + def __len__(self): + return len(self.data) + def __getitem__(self, i): if isinstance(i, slice): return self.__class__(self.data[i]) else: return self.data[i] - def __setitem__(self, i, item): self.data[i] = item - def __delitem__(self, i): del self.data[i] + + def __setitem__(self, i, item): + self.data[i] = item + + def __delitem__(self, i): + del self.data[i] + def __add__(self, other): if isinstance(other, UserList): return self.__class__(self.data + other.data) elif isinstance(other, type(self.data)): return self.__class__(self.data + other) return self.__class__(self.data + list(other)) + def __radd__(self, other): if isinstance(other, UserList): return self.__class__(other.data + self.data) elif isinstance(other, type(self.data)): return self.__class__(other + self.data) return self.__class__(list(other) + self.data) + def __iadd__(self, other): if isinstance(other, UserList): self.data += other.data @@ -1192,28 +1252,53 @@ class UserList(_collections_abc.MutableSequence): else: self.data += list(other) return self + def __mul__(self, n): - return self.__class__(self.data*n) + return self.__class__(self.data * n) + __rmul__ = __mul__ + def __imul__(self, n): self.data *= n return self + def __copy__(self): inst = self.__class__.__new__(self.__class__) inst.__dict__.update(self.__dict__) # Create a copy and avoid triggering descriptors inst.__dict__["data"] = self.__dict__["data"][:] return inst - def append(self, item): self.data.append(item) - def insert(self, i, item): self.data.insert(i, item) - def pop(self, i=-1): return self.data.pop(i) - def remove(self, item): self.data.remove(item) - def clear(self): self.data.clear() - def copy(self): return self.__class__(self) - def count(self, item): return self.data.count(item) - def index(self, item, *args): return self.data.index(item, *args) - def reverse(self): self.data.reverse() - def sort(self, /, *args, **kwds): self.data.sort(*args, **kwds) + + def append(self, item): + self.data.append(item) + + def insert(self, i, item): + self.data.insert(i, item) + + def pop(self, i=-1): + return self.data.pop(i) + + def remove(self, item): + self.data.remove(item) + + def clear(self): + self.data.clear() + + def copy(self): + return self.__class__(self) + + def count(self, item): + return self.data.count(item) + + def index(self, item, *args): + return self.data.index(item, *args) + + def reverse(self): + self.data.reverse() + + def sort(self, /, *args, **kwds): + self.data.sort(*args, **kwds) + def extend(self, other): if isinstance(other, UserList): self.data.extend(other.data) @@ -1221,12 +1306,12 @@ class UserList(_collections_abc.MutableSequence): self.data.extend(other) - ################################################################################ ### UserString ################################################################################ class UserString(_collections_abc.Sequence): + def __init__(self, seq): if isinstance(seq, str): self.data = seq @@ -1234,12 +1319,25 @@ class UserString(_collections_abc.Sequence): self.data = seq.data[:] else: self.data = str(seq) - def __str__(self): return str(self.data) - def __repr__(self): return repr(self.data) - def __int__(self): return int(self.data) - def __float__(self): return float(self.data) - def __complex__(self): return complex(self.data) - def __hash__(self): return hash(self.data) + + def __str__(self): + return str(self.data) + + def __repr__(self): + return repr(self.data) + + def __int__(self): + return int(self.data) + + def __float__(self): + return float(self.data) + + def __complex__(self): + return complex(self.data) + + def __hash__(self): + return hash(self.data) + def __getnewargs__(self): return (self.data[:],) @@ -1247,18 +1345,22 @@ class UserString(_collections_abc.Sequence): if isinstance(string, UserString): return self.data == string.data return self.data == string + def __lt__(self, string): if isinstance(string, UserString): return self.data < string.data return self.data < string + def __le__(self, string): if isinstance(string, UserString): return self.data <= string.data return self.data <= string + def __gt__(self, string): if isinstance(string, UserString): return self.data > string.data return self.data > string + def __ge__(self, string): if isinstance(string, UserString): return self.data >= string.data @@ -1269,110 +1371,188 @@ class UserString(_collections_abc.Sequence): char = char.data return char in self.data - def __len__(self): return len(self.data) - def __getitem__(self, index): return self.__class__(self.data[index]) + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.__class__(self.data[index]) + def __add__(self, other): if isinstance(other, UserString): return self.__class__(self.data + other.data) elif isinstance(other, str): return self.__class__(self.data + other) return self.__class__(self.data + str(other)) + def __radd__(self, other): if isinstance(other, str): return self.__class__(other + self.data) return self.__class__(str(other) + self.data) + def __mul__(self, n): - return self.__class__(self.data*n) + return self.__class__(self.data * n) + __rmul__ = __mul__ + def __mod__(self, args): return self.__class__(self.data % args) + def __rmod__(self, template): return self.__class__(str(template) % self) + # the following methods are defined in alphabetical order: - def capitalize(self): return self.__class__(self.data.capitalize()) + def capitalize(self): + return self.__class__(self.data.capitalize()) + def casefold(self): return self.__class__(self.data.casefold()) + def center(self, width, *args): return self.__class__(self.data.center(width, *args)) + def count(self, sub, start=0, end=_sys.maxsize): if isinstance(sub, UserString): sub = sub.data return self.data.count(sub, start, end) + def removeprefix(self, prefix, /): if isinstance(prefix, UserString): prefix = prefix.data return self.__class__(self.data.removeprefix(prefix)) + def removesuffix(self, suffix, /): if isinstance(suffix, UserString): suffix = suffix.data return self.__class__(self.data.removesuffix(suffix)) + def encode(self, encoding='utf-8', errors='strict'): encoding = 'utf-8' if encoding is None else encoding errors = 'strict' if errors is None else errors return self.data.encode(encoding, errors) + def endswith(self, suffix, start=0, end=_sys.maxsize): return self.data.endswith(suffix, start, end) + def expandtabs(self, tabsize=8): return self.__class__(self.data.expandtabs(tabsize)) + def find(self, sub, start=0, end=_sys.maxsize): if isinstance(sub, UserString): sub = sub.data return self.data.find(sub, start, end) + def format(self, /, *args, **kwds): return self.data.format(*args, **kwds) + def format_map(self, mapping): return self.data.format_map(mapping) + def index(self, sub, start=0, end=_sys.maxsize): return self.data.index(sub, start, end) - def isalpha(self): return self.data.isalpha() - def isalnum(self): return self.data.isalnum() - def isascii(self): return self.data.isascii() - def isdecimal(self): return self.data.isdecimal() - def isdigit(self): return self.data.isdigit() - def isidentifier(self): return self.data.isidentifier() - def islower(self): return self.data.islower() - def isnumeric(self): return self.data.isnumeric() - def isprintable(self): return self.data.isprintable() - def isspace(self): return self.data.isspace() - def istitle(self): return self.data.istitle() - def isupper(self): return self.data.isupper() - def join(self, seq): return self.data.join(seq) + + def isalpha(self): + return self.data.isalpha() + + def isalnum(self): + return self.data.isalnum() + + def isascii(self): + return self.data.isascii() + + def isdecimal(self): + return self.data.isdecimal() + + def isdigit(self): + return self.data.isdigit() + + def isidentifier(self): + return self.data.isidentifier() + + def islower(self): + return self.data.islower() + + def isnumeric(self): + return self.data.isnumeric() + + def isprintable(self): + return self.data.isprintable() + + def isspace(self): + return self.data.isspace() + + def istitle(self): + return self.data.istitle() + + def isupper(self): + return self.data.isupper() + + def join(self, seq): + return self.data.join(seq) + def ljust(self, width, *args): return self.__class__(self.data.ljust(width, *args)) - def lower(self): return self.__class__(self.data.lower()) - def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars)) + + def lower(self): + return self.__class__(self.data.lower()) + + def lstrip(self, chars=None): + return self.__class__(self.data.lstrip(chars)) + maketrans = str.maketrans + def partition(self, sep): return self.data.partition(sep) + def replace(self, old, new, maxsplit=-1): if isinstance(old, UserString): old = old.data if isinstance(new, UserString): new = new.data return self.__class__(self.data.replace(old, new, maxsplit)) + def rfind(self, sub, start=0, end=_sys.maxsize): if isinstance(sub, UserString): sub = sub.data return self.data.rfind(sub, start, end) + def rindex(self, sub, start=0, end=_sys.maxsize): return self.data.rindex(sub, start, end) + def rjust(self, width, *args): return self.__class__(self.data.rjust(width, *args)) + def rpartition(self, sep): return self.data.rpartition(sep) + def rstrip(self, chars=None): return self.__class__(self.data.rstrip(chars)) + def split(self, sep=None, maxsplit=-1): return self.data.split(sep, maxsplit) + def rsplit(self, sep=None, maxsplit=-1): return self.data.rsplit(sep, maxsplit) - def splitlines(self, keepends=False): return self.data.splitlines(keepends) + + def splitlines(self, keepends=False): + return self.data.splitlines(keepends) + def startswith(self, prefix, start=0, end=_sys.maxsize): return self.data.startswith(prefix, start, end) - def strip(self, chars=None): return self.__class__(self.data.strip(chars)) - def swapcase(self): return self.__class__(self.data.swapcase()) - def title(self): return self.__class__(self.data.title()) + + def strip(self, chars=None): + return self.__class__(self.data.strip(chars)) + + def swapcase(self): + return self.__class__(self.data.swapcase()) + + def title(self): + return self.__class__(self.data.title()) + def translate(self, *args): return self.__class__(self.data.translate(*args)) - def upper(self): return self.__class__(self.data.upper()) - def zfill(self, width): return self.__class__(self.data.zfill(width)) + + def upper(self): + return self.__class__(self.data.upper()) + + def zfill(self, width): + return self.__class__(self.data.zfill(width)) |