summaryrefslogtreecommitdiffstats
path: root/Lib/collections
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2020-06-11 06:17:58 (GMT)
committerGitHub <noreply@github.com>2020-06-11 06:17:58 (GMT)
commit31d17798d6567036d3ac2771555a919b3628962f (patch)
tree0a418e6a07c3655aa68fdf39c380cc6e72c3aaa8 /Lib/collections
parent896f4cf63f9ab93e30572d879a5719d5aa2499fb (diff)
downloadcpython-31d17798d6567036d3ac2771555a919b3628962f.zip
cpython-31d17798d6567036d3ac2771555a919b3628962f.tar.gz
cpython-31d17798d6567036d3ac2771555a919b3628962f.tar.bz2
Collections module reformatting and minor code refactoring (GH-20772)
Diffstat (limited to 'Lib/collections')
-rw-r--r--Lib/collections/__init__.py362
1 files changed, 271 insertions, 91 deletions
diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py
index 6a06cc6..42d0ec7 100644
--- a/Lib/collections/__init__.py
+++ b/Lib/collections/__init__.py
@@ -14,17 +14,30 @@ list, set, and tuple.
'''
-__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList',
- 'UserString', 'Counter', 'OrderedDict', 'ChainMap']
+__all__ = [
+ 'ChainMap',
+ 'Counter',
+ 'OrderedDict',
+ 'UserDict',
+ 'UserList',
+ 'UserString',
+ 'defaultdict',
+ 'deque',
+ 'namedtuple',
+]
import _collections_abc
-from operator import itemgetter as _itemgetter, eq as _eq
-from keyword import iskeyword as _iskeyword
-import sys as _sys
import heapq as _heapq
-from _weakref import proxy as _proxy
-from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
+import sys as _sys
+
+from itertools import chain as _chain
+from itertools import repeat as _repeat
+from itertools import starmap as _starmap
+from keyword import iskeyword as _iskeyword
+from operator import eq as _eq
+from operator import itemgetter as _itemgetter
from reprlib import recursive_repr as _recursive_repr
+from _weakref import proxy as _proxy
try:
from _collections import deque
@@ -54,6 +67,7 @@ def __getattr__(name):
return obj
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
+
################################################################################
### OrderedDict
################################################################################
@@ -408,10 +422,13 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
# Create all the named tuple methods to be added to the class namespace
- s = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
- namespace = {'_tuple_new': tuple_new, '__builtins__': None,
- '__name__': f'namedtuple_{typename}'}
- __new__ = eval(s, namespace)
+ namespace = {
+ '_tuple_new': tuple_new,
+ '__builtins__': None,
+ '__name__': f'namedtuple_{typename}',
+ }
+ code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
+ __new__ = eval(code, namespace)
__new__.__name__ = '__new__'
__new__.__doc__ = f'Create new instance of {typename}({arg_list})'
if defaults is not None:
@@ -449,8 +466,14 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
return _tuple(self)
# Modify function metadata to help with introspection and debugging
- for method in (__new__, _make.__func__, _replace,
- __repr__, _asdict, __getnewargs__):
+ for method in (
+ __new__,
+ _make.__func__,
+ _replace,
+ __repr__,
+ _asdict,
+ __getnewargs__,
+ ):
method.__qualname__ = f'{typename}.{method.__name__}'
# Build-up the class namespace dictionary
@@ -566,7 +589,7 @@ class Counter(dict):
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
- super(Counter, self).__init__()
+ super().__init__()
self.update(iterable, **kwds)
def __missing__(self, key):
@@ -650,7 +673,8 @@ class Counter(dict):
for elem, count in iterable.items():
self[elem] = count + self_get(elem, 0)
else:
- super(Counter, self).update(iterable) # fast path when counter is empty
+ # fast path when counter is empty
+ super().update(iterable)
else:
_count_elements(self, iterable)
if kwds:
@@ -733,13 +757,14 @@ class Counter(dict):
def __repr__(self):
if not self:
- return '%s()' % self.__class__.__name__
+ return f'{self.__class__.__name__}()'
try:
- items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
- return '%s({%s})' % (self.__class__.__name__, items)
+ # dict() preserves the ordering returned by most_common()
+ d = dict(self.most_common())
except TypeError:
# handle case where values are not orderable
- return '{0}({1!r})'.format(self.__class__.__name__, dict(self))
+ d = dict(self)
+ return f'{self.__class__.__name__}({d!r})'
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
@@ -1018,7 +1043,7 @@ class ChainMap(_collections_abc.MutableMapping):
try:
del self.maps[0][key]
except KeyError:
- raise KeyError('Key not found in the first mapping: {!r}'.format(key))
+ raise KeyError(f'Key not found in the first mapping: {key!r}')
def popitem(self):
'Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty.'
@@ -1032,30 +1057,30 @@ class ChainMap(_collections_abc.MutableMapping):
try:
return self.maps[0].pop(key, *args)
except KeyError:
- raise KeyError('Key not found in the first mapping: {!r}'.format(key))
+ raise KeyError(f'Key not found in the first mapping: {key!r}')
def clear(self):
'Clear maps[0], leaving maps[1:] intact.'
self.maps[0].clear()
def __ior__(self, other):
- self.maps[0] |= other
+ self.maps[0].update(other)
return self
def __or__(self, other):
- if isinstance(other, _collections_abc.Mapping):
- m = self.maps[0].copy()
- m.update(other)
- return self.__class__(m, *self.maps[1:])
- return NotImplemented
+ if not isinstance(other, _collections_abc.Mapping):
+ return NotImplemented
+ m = self.copy()
+ m.maps[0].update(other)
+ return m
def __ror__(self, other):
- if isinstance(other, _collections_abc.Mapping):
- m = dict(other)
- for child in reversed(self.maps):
- m.update(child)
- return self.__class__(m)
- return NotImplemented
+ if not isinstance(other, _collections_abc.Mapping):
+ return NotImplemented
+ m = dict(other)
+ for child in reversed(self.maps):
+ m.update(child)
+ return self.__class__(m)
################################################################################
@@ -1072,15 +1097,22 @@ class UserDict(_collections_abc.MutableMapping):
if kwargs:
self.update(kwargs)
- def __len__(self): return len(self.data)
+ def __len__(self):
+ return len(self.data)
+
def __getitem__(self, key):
if key in self.data:
return self.data[key]
if hasattr(self.__class__, "__missing__"):
return self.__class__.__missing__(self, key)
raise KeyError(key)
- def __setitem__(self, key, item): self.data[key] = item
- def __delitem__(self, key): del self.data[key]
+
+ def __setitem__(self, key, item):
+ self.data[key] = item
+
+ def __delitem__(self, key):
+ del self.data[key]
+
def __iter__(self):
return iter(self.data)
@@ -1089,7 +1121,8 @@ class UserDict(_collections_abc.MutableMapping):
return key in self.data
# Now, add the methods in dicts but not in MutableMapping
- def __repr__(self): return repr(self.data)
+ def __repr__(self):
+ return repr(self.data)
def __or__(self, other):
if isinstance(other, UserDict):
@@ -1097,12 +1130,14 @@ class UserDict(_collections_abc.MutableMapping):
if isinstance(other, dict):
return self.__class__(self.data | other)
return NotImplemented
+
def __ror__(self, other):
if isinstance(other, UserDict):
return self.__class__(other.data | self.data)
if isinstance(other, dict):
return self.__class__(other | self.data)
return NotImplemented
+
def __ior__(self, other):
if isinstance(other, UserDict):
self.data |= other.data
@@ -1138,13 +1173,13 @@ class UserDict(_collections_abc.MutableMapping):
return d
-
################################################################################
### UserList
################################################################################
class UserList(_collections_abc.MutableSequence):
"""A more or less complete user-defined wrapper around list objects."""
+
def __init__(self, initlist=None):
self.data = []
if initlist is not None:
@@ -1155,35 +1190,60 @@ class UserList(_collections_abc.MutableSequence):
self.data[:] = initlist.data[:]
else:
self.data = list(initlist)
- def __repr__(self): return repr(self.data)
- def __lt__(self, other): return self.data < self.__cast(other)
- def __le__(self, other): return self.data <= self.__cast(other)
- def __eq__(self, other): return self.data == self.__cast(other)
- def __gt__(self, other): return self.data > self.__cast(other)
- def __ge__(self, other): return self.data >= self.__cast(other)
+
+ def __repr__(self):
+ return repr(self.data)
+
+ def __lt__(self, other):
+ return self.data < self.__cast(other)
+
+ def __le__(self, other):
+ return self.data <= self.__cast(other)
+
+ def __eq__(self, other):
+ return self.data == self.__cast(other)
+
+ def __gt__(self, other):
+ return self.data > self.__cast(other)
+
+ def __ge__(self, other):
+ return self.data >= self.__cast(other)
+
def __cast(self, other):
return other.data if isinstance(other, UserList) else other
- def __contains__(self, item): return item in self.data
- def __len__(self): return len(self.data)
+
+ def __contains__(self, item):
+ return item in self.data
+
+ def __len__(self):
+ return len(self.data)
+
def __getitem__(self, i):
if isinstance(i, slice):
return self.__class__(self.data[i])
else:
return self.data[i]
- def __setitem__(self, i, item): self.data[i] = item
- def __delitem__(self, i): del self.data[i]
+
+ def __setitem__(self, i, item):
+ self.data[i] = item
+
+ def __delitem__(self, i):
+ del self.data[i]
+
def __add__(self, other):
if isinstance(other, UserList):
return self.__class__(self.data + other.data)
elif isinstance(other, type(self.data)):
return self.__class__(self.data + other)
return self.__class__(self.data + list(other))
+
def __radd__(self, other):
if isinstance(other, UserList):
return self.__class__(other.data + self.data)
elif isinstance(other, type(self.data)):
return self.__class__(other + self.data)
return self.__class__(list(other) + self.data)
+
def __iadd__(self, other):
if isinstance(other, UserList):
self.data += other.data
@@ -1192,28 +1252,53 @@ class UserList(_collections_abc.MutableSequence):
else:
self.data += list(other)
return self
+
def __mul__(self, n):
- return self.__class__(self.data*n)
+ return self.__class__(self.data * n)
+
__rmul__ = __mul__
+
def __imul__(self, n):
self.data *= n
return self
+
def __copy__(self):
inst = self.__class__.__new__(self.__class__)
inst.__dict__.update(self.__dict__)
# Create a copy and avoid triggering descriptors
inst.__dict__["data"] = self.__dict__["data"][:]
return inst
- def append(self, item): self.data.append(item)
- def insert(self, i, item): self.data.insert(i, item)
- def pop(self, i=-1): return self.data.pop(i)
- def remove(self, item): self.data.remove(item)
- def clear(self): self.data.clear()
- def copy(self): return self.__class__(self)
- def count(self, item): return self.data.count(item)
- def index(self, item, *args): return self.data.index(item, *args)
- def reverse(self): self.data.reverse()
- def sort(self, /, *args, **kwds): self.data.sort(*args, **kwds)
+
+ def append(self, item):
+ self.data.append(item)
+
+ def insert(self, i, item):
+ self.data.insert(i, item)
+
+ def pop(self, i=-1):
+ return self.data.pop(i)
+
+ def remove(self, item):
+ self.data.remove(item)
+
+ def clear(self):
+ self.data.clear()
+
+ def copy(self):
+ return self.__class__(self)
+
+ def count(self, item):
+ return self.data.count(item)
+
+ def index(self, item, *args):
+ return self.data.index(item, *args)
+
+ def reverse(self):
+ self.data.reverse()
+
+ def sort(self, /, *args, **kwds):
+ self.data.sort(*args, **kwds)
+
def extend(self, other):
if isinstance(other, UserList):
self.data.extend(other.data)
@@ -1221,12 +1306,12 @@ class UserList(_collections_abc.MutableSequence):
self.data.extend(other)
-
################################################################################
### UserString
################################################################################
class UserString(_collections_abc.Sequence):
+
def __init__(self, seq):
if isinstance(seq, str):
self.data = seq
@@ -1234,12 +1319,25 @@ class UserString(_collections_abc.Sequence):
self.data = seq.data[:]
else:
self.data = str(seq)
- def __str__(self): return str(self.data)
- def __repr__(self): return repr(self.data)
- def __int__(self): return int(self.data)
- def __float__(self): return float(self.data)
- def __complex__(self): return complex(self.data)
- def __hash__(self): return hash(self.data)
+
+ def __str__(self):
+ return str(self.data)
+
+ def __repr__(self):
+ return repr(self.data)
+
+ def __int__(self):
+ return int(self.data)
+
+ def __float__(self):
+ return float(self.data)
+
+ def __complex__(self):
+ return complex(self.data)
+
+ def __hash__(self):
+ return hash(self.data)
+
def __getnewargs__(self):
return (self.data[:],)
@@ -1247,18 +1345,22 @@ class UserString(_collections_abc.Sequence):
if isinstance(string, UserString):
return self.data == string.data
return self.data == string
+
def __lt__(self, string):
if isinstance(string, UserString):
return self.data < string.data
return self.data < string
+
def __le__(self, string):
if isinstance(string, UserString):
return self.data <= string.data
return self.data <= string
+
def __gt__(self, string):
if isinstance(string, UserString):
return self.data > string.data
return self.data > string
+
def __ge__(self, string):
if isinstance(string, UserString):
return self.data >= string.data
@@ -1269,110 +1371,188 @@ class UserString(_collections_abc.Sequence):
char = char.data
return char in self.data
- def __len__(self): return len(self.data)
- def __getitem__(self, index): return self.__class__(self.data[index])
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ return self.__class__(self.data[index])
+
def __add__(self, other):
if isinstance(other, UserString):
return self.__class__(self.data + other.data)
elif isinstance(other, str):
return self.__class__(self.data + other)
return self.__class__(self.data + str(other))
+
def __radd__(self, other):
if isinstance(other, str):
return self.__class__(other + self.data)
return self.__class__(str(other) + self.data)
+
def __mul__(self, n):
- return self.__class__(self.data*n)
+ return self.__class__(self.data * n)
+
__rmul__ = __mul__
+
def __mod__(self, args):
return self.__class__(self.data % args)
+
def __rmod__(self, template):
return self.__class__(str(template) % self)
+
# the following methods are defined in alphabetical order:
- def capitalize(self): return self.__class__(self.data.capitalize())
+ def capitalize(self):
+ return self.__class__(self.data.capitalize())
+
def casefold(self):
return self.__class__(self.data.casefold())
+
def center(self, width, *args):
return self.__class__(self.data.center(width, *args))
+
def count(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.count(sub, start, end)
+
def removeprefix(self, prefix, /):
if isinstance(prefix, UserString):
prefix = prefix.data
return self.__class__(self.data.removeprefix(prefix))
+
def removesuffix(self, suffix, /):
if isinstance(suffix, UserString):
suffix = suffix.data
return self.__class__(self.data.removesuffix(suffix))
+
def encode(self, encoding='utf-8', errors='strict'):
encoding = 'utf-8' if encoding is None else encoding
errors = 'strict' if errors is None else errors
return self.data.encode(encoding, errors)
+
def endswith(self, suffix, start=0, end=_sys.maxsize):
return self.data.endswith(suffix, start, end)
+
def expandtabs(self, tabsize=8):
return self.__class__(self.data.expandtabs(tabsize))
+
def find(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.find(sub, start, end)
+
def format(self, /, *args, **kwds):
return self.data.format(*args, **kwds)
+
def format_map(self, mapping):
return self.data.format_map(mapping)
+
def index(self, sub, start=0, end=_sys.maxsize):
return self.data.index(sub, start, end)
- def isalpha(self): return self.data.isalpha()
- def isalnum(self): return self.data.isalnum()
- def isascii(self): return self.data.isascii()
- def isdecimal(self): return self.data.isdecimal()
- def isdigit(self): return self.data.isdigit()
- def isidentifier(self): return self.data.isidentifier()
- def islower(self): return self.data.islower()
- def isnumeric(self): return self.data.isnumeric()
- def isprintable(self): return self.data.isprintable()
- def isspace(self): return self.data.isspace()
- def istitle(self): return self.data.istitle()
- def isupper(self): return self.data.isupper()
- def join(self, seq): return self.data.join(seq)
+
+ def isalpha(self):
+ return self.data.isalpha()
+
+ def isalnum(self):
+ return self.data.isalnum()
+
+ def isascii(self):
+ return self.data.isascii()
+
+ def isdecimal(self):
+ return self.data.isdecimal()
+
+ def isdigit(self):
+ return self.data.isdigit()
+
+ def isidentifier(self):
+ return self.data.isidentifier()
+
+ def islower(self):
+ return self.data.islower()
+
+ def isnumeric(self):
+ return self.data.isnumeric()
+
+ def isprintable(self):
+ return self.data.isprintable()
+
+ def isspace(self):
+ return self.data.isspace()
+
+ def istitle(self):
+ return self.data.istitle()
+
+ def isupper(self):
+ return self.data.isupper()
+
+ def join(self, seq):
+ return self.data.join(seq)
+
def ljust(self, width, *args):
return self.__class__(self.data.ljust(width, *args))
- def lower(self): return self.__class__(self.data.lower())
- def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars))
+
+ def lower(self):
+ return self.__class__(self.data.lower())
+
+ def lstrip(self, chars=None):
+ return self.__class__(self.data.lstrip(chars))
+
maketrans = str.maketrans
+
def partition(self, sep):
return self.data.partition(sep)
+
def replace(self, old, new, maxsplit=-1):
if isinstance(old, UserString):
old = old.data
if isinstance(new, UserString):
new = new.data
return self.__class__(self.data.replace(old, new, maxsplit))
+
def rfind(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.rfind(sub, start, end)
+
def rindex(self, sub, start=0, end=_sys.maxsize):
return self.data.rindex(sub, start, end)
+
def rjust(self, width, *args):
return self.__class__(self.data.rjust(width, *args))
+
def rpartition(self, sep):
return self.data.rpartition(sep)
+
def rstrip(self, chars=None):
return self.__class__(self.data.rstrip(chars))
+
def split(self, sep=None, maxsplit=-1):
return self.data.split(sep, maxsplit)
+
def rsplit(self, sep=None, maxsplit=-1):
return self.data.rsplit(sep, maxsplit)
- def splitlines(self, keepends=False): return self.data.splitlines(keepends)
+
+ def splitlines(self, keepends=False):
+ return self.data.splitlines(keepends)
+
def startswith(self, prefix, start=0, end=_sys.maxsize):
return self.data.startswith(prefix, start, end)
- def strip(self, chars=None): return self.__class__(self.data.strip(chars))
- def swapcase(self): return self.__class__(self.data.swapcase())
- def title(self): return self.__class__(self.data.title())
+
+ def strip(self, chars=None):
+ return self.__class__(self.data.strip(chars))
+
+ def swapcase(self):
+ return self.__class__(self.data.swapcase())
+
+ def title(self):
+ return self.__class__(self.data.title())
+
def translate(self, *args):
return self.__class__(self.data.translate(*args))
- def upper(self): return self.__class__(self.data.upper())
- def zfill(self, width): return self.__class__(self.data.zfill(width))
+
+ def upper(self):
+ return self.__class__(self.data.upper())
+
+ def zfill(self, width):
+ return self.__class__(self.data.zfill(width))