__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList',
           'UserString', 'Counter']

# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py.
# They should however be considered an integral part of collections.py.
try:
    from _abcoll import *
    import _abcoll
except ImportError:
    # Fallback for interpreters that provide the same ABC definitions under
    # the module name _collections_abc instead of _abcoll.
    from _collections_abc import *
    import _collections_abc as _abcoll
__all__ += _abcoll.__all__

from _collections import deque, defaultdict
from operator import itemgetter as _itemgetter
from keyword import iskeyword as _iskeyword
import sys as _sys
import heapq as _heapq
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap

################################################################################
### namedtuple
################################################################################

def namedtuple(typename, field_names, verbose=False, rename=False):
    """Returns a new subclass of tuple with named fields.

    typename is the name of the generated class. field_names is either a
    single string of names separated by whitespace and/or commas, or an
    iterable of names. If verbose is true, the generated class source is
    printed before it is executed. If rename is true, invalid or duplicate
    field names are replaced by an underscore followed by their one-based
    position instead of being rejected.

    Raises ValueError for an empty, non-alphanumeric, keyword, digit-led,
    underscore-led (unless rename is true) or duplicate name.

    >>> Point = namedtuple('Point', 'x y')
    >>> Point.__doc__ # docstring for the new class
    'Point(x, y)'
    >>> p = Point(11, y=22) # instantiate with positional args or keywords
    >>> p[0] + p[1] # indexable like a plain tuple
    33
    >>> x, y = p # unpack like a regular tuple
    >>> x, y
    (11, 22)
    >>> p.x + p.y # fields also accessible by name
    33
    >>> d = p._asdict() # convert to a dictionary
    >>> d['x']
    11
    >>> Point(**d) # convert from a dictionary
    Point(x=11, y=22)
    >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
    Point(x=100, y=22)

    """

    # Parse and validate the field names. Validation serves two purposes,
    # generating informative error messages and preventing template injection attacks.
    if isinstance(field_names, str):
        # names separated by whitespace and/or commas
        field_names = field_names.replace(',', ' ').split()
    field_names = tuple(map(str, field_names))
    if rename:
        names = list(field_names)
        seen = set()
        for i, name in enumerate(names):
            if (not all(c.isalnum() or c=='_' for c in name) or _iskeyword(name)
                or not name or name[0].isdigit() or name.startswith('_')
                or name in seen):
                names[i] = '_%d' % (i+1)
            seen.add(name)
        field_names = tuple(names)
    for name in (typename,) + field_names:
        # An empty name passes the character test below (all() of nothing is
        # true) and would then fail on name[0] with an IndexError, so reject
        # it explicitly with the same exception type as every other bad name.
        if not name:
            raise ValueError('Type names and field names cannot be empty: %r' % name)
        if not all(c.isalnum() or c=='_' for c in name):
            raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
        if _iskeyword(name):
            raise ValueError('Type names and field names cannot be a keyword: %r' % name)
        if name[0].isdigit():
            raise ValueError('Type names and field names cannot start with a number: %r' % name)
    seen_names = set()
    for name in field_names:
        if name.startswith('_') and not rename:
            raise ValueError('Field names cannot start with an underscore: %r' % name)
        if name in seen_names:
            raise ValueError('Encountered duplicate field name: %r' % name)
        seen_names.add(name)

    # Create and fill-in the class template
    numfields = len(field_names)
    argtxt = repr(field_names).replace("'", "")[1:-1]   # tuple repr without parens or quotes
    reprtxt = ', '.join('%s=%%r' % name for name in field_names)
    dicttxt = ', '.join('%r: t[%d]' % (name, pos) for pos, name in enumerate(field_names))
    # Explicit substitution mapping (rather than locals()) so the template's
    # inputs are visible in one place.
    substitutions = dict(typename=typename, field_names=field_names,
                         numfields=numfields, argtxt=argtxt,
                         reprtxt=reprtxt, dicttxt=dicttxt)
    template = '''class %(typename)s(tuple):
        '%(typename)s(%(argtxt)s)' \n
        __slots__ = () \n
        _fields = %(field_names)r \n
        def __new__(cls, %(argtxt)s):
            return tuple.__new__(cls, (%(argtxt)s)) \n
        @classmethod
        def _make(cls, iterable, new=tuple.__new__, len=len):
            'Make a new %(typename)s object from a sequence or iterable'
            result = new(cls, iterable)
            if len(result) != %(numfields)d:
                raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
            return result \n
        def __repr__(self):
            return '%(typename)s(%(reprtxt)s)' %% self \n
        def _asdict(t):
            'Return a new dict which maps field names to their values'
            return {%(dicttxt)s} \n
        def _replace(self, **kwds):
            'Return a new %(typename)s object replacing specified fields with new values'
            result = self._make(map(kwds.pop, %(field_names)r, self))
            if kwds:
                raise ValueError('Got unexpected field names: %%r' %% list(kwds))
            return result \n
        def __getnewargs__(self):
            return tuple(self) \n\n''' % substitutions
    # One read-only property per field; the indentation must match the class
    # body indentation used inside the template above.
    for i, name in enumerate(field_names):
        template += '        %s = property(itemgetter(%d))\n' % (name, i)
    if verbose:
        print(template)

    # Execute the template string in a temporary namespace and
    # support tracing utilities by setting a value for frame.f_globals['__name__']
    namespace = dict(itemgetter=_itemgetter, __name__='namedtuple_%s' % typename)
    try:
        exec(template, namespace)
    except SyntaxError as e:
        raise SyntaxError(e.msg + ':\n' + template) from e
    result = namespace[typename]

    # For pickling to work, the __module__ variable needs to be set to the frame
    # where the named tuple is created. Bypass this step in environments where
    # sys._getframe is not defined (Jython for example).
    if hasattr(_sys, '_getframe'):
        result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')

    return result


########################################################################
### Counter
########################################################################

class Counter(dict):
    '''Dict subclass for counting hashable items. Sometimes called a bag
    or multiset. Elements are stored as dictionary keys and their counts
    are stored as dictionary values.

    >>> c = Counter('abracadabra') # count elements from a string

    >>> c.most_common(3) # three most common elements
    [('a', 5), ('r', 2), ('b', 2)]
    >>> sorted(c) # list all unique elements
    ['a', 'b', 'c', 'd', 'r']
    >>> ''.join(sorted(c.elements())) # list elements with repetitions
    'aaaaabbcdrr'
    >>> sum(c.values()) # total of all counts
    11

    >>> c['a'] # count of letter 'a'
    5
    >>> for elem in 'shazam': # update counts from an iterable
    ...     c[elem] += 1 # by adding 1 to each element's count
    >>> c['a'] # now there are seven 'a'
    7
    >>> del c['r'] # remove all 'r'
    >>> c['r'] # now there are zero 'r'
    0

    >>> d = Counter('simsalabim') # make another counter
    >>> c.update(d) # add in the second counter
    >>> c['a'] # now there are nine 'a'
    9

    >>> c.clear() # empty the counter
    >>> c
    Counter()

    Note: If a count is set to zero or reduced to zero, it will remain
    in the counter until the entry is deleted or the counter is cleared:

    >>> c = Counter('aaabbc')
    >>> c['b'] -= 2 # reduce the count of 'b' by two
    >>> c.most_common() # 'b' is still in, but its count is zero
    [('a', 3), ('c', 1), ('b', 0)]

    '''
    # References:
    #   http://en.wikipedia.org/wiki/Multiset
    #   http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
    #   http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
    #   http://code.activestate.com/recipes/259174/
    #   Knuth, TAOCP Vol. II section 4.6.3

    def __init__(self, iterable=None, **kwds):
        '''Create a new, empty Counter object. And if given, count elements
        from an input iterable. Or, initialize the count from another mapping
        of elements to their counts.

        >>> c = Counter() # a new, empty counter
        >>> c = Counter('gallahad') # a new counter from an iterable
        >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
        >>> c = Counter(a=4, b=2) # a new counter from keyword args

        '''
        self.update(iterable, **kwds)

    def __missing__(self, key):
        'The count of elements not in the Counter is zero.'
        # Needed so that self[missing_item] does not raise KeyError
        return 0

    def most_common(self, n=None):
        '''List the n most common elements and their counts from the most
        common to the least. If n is None, then list all element counts.

        >>> Counter('abracadabra').most_common(3)
        [('a', 5), ('r', 2), ('b', 2)]

        '''
        # Emulate Bag.sortedByCount from Smalltalk
        if n is None:
            return sorted(self.items(), key=_itemgetter(1), reverse=True)
        return _heapq.nlargest(n, self.items(), key=_itemgetter(1))

    def elements(self):
        '''Iterator over elements repeating each as many times as its count.

        >>> c = Counter('ABCABC')
        >>> sorted(c.elements())
        ['A', 'A', 'B', 'B', 'C', 'C']

        # Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1
        >>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
        >>> product = 1
        >>> for factor in prime_factors.elements(): # loop over factors
        ...     product *= factor # and multiply them
        >>> product
        1836

        Note, if an element's count has been set to zero or is a negative
        number, elements() will ignore it.

        '''
        # Emulate Bag.do from Smalltalk and Multiset.begin from C++.
        return _chain.from_iterable(_starmap(_repeat, self.items()))

    # Override dict methods where necessary

    @classmethod
    def fromkeys(cls, iterable, v=None):
        # There is no equivalent method for counters because setting v=1
        # means that no element can have a count greater than one.
        raise NotImplementedError(
            'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')

    def update(self, iterable=None, **kwds):
        '''Like dict.update() but add counts instead of replacing them.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.update('witch') # add elements from another iterable
        >>> d = Counter('watch')
        >>> c.update(d) # add elements from another counter
        >>> c['h'] # four 'h' in which, witch, and watch
        4

        '''
        # The regular dict.update() operation makes no sense here because the
        # replace behavior results in some of the original untouched counts
        # being mixed-in with all of the other counts for a mishmash that
        # doesn't have a straight-forward interpretation in most counting
        # contexts. Instead, we implement straight-addition. Both the inputs
        # and outputs are allowed to contain zero and negative counts.

        if iterable is not None:
            if isinstance(iterable, Mapping):
                if self:
                    for elem, count in iterable.items():
                        self[elem] += count
                else:
                    dict.update(self, iterable) # fast path when counter is empty
            else:
                for elem in iterable:
                    self[elem] += 1
        if kwds:
            self.update(kwds)

    def copy(self):
        'Like dict.copy() but returns an instance of the same class instead of a dict.'
        # Use the runtime class so that copying a subclass keeps its type.
        return self.__class__(self)

    def __delitem__(self, elem):
        'Like dict.__delitem__() but does not raise KeyError for missing values.'
        if elem in self:
            dict.__delitem__(self, elem)

    def __repr__(self):
        if not self:
            return '%s()' % self.__class__.__name__
        items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
        return '%s({%s})' % (self.__class__.__name__, items)

    # Multiset-style mathematical operations discussed in:
    #       Knuth TAOCP Volume II section 4.6.3 exercise 19
    #       and at http://en.wikipedia.org/wiki/Multiset
    #
    # Outputs guaranteed to only include positive counts.
    #
    # To strip negative and zero counts, add-in an empty counter:
    #       c += Counter()

    def __add__(self, other):
        '''Add counts from two counters.

        >>> Counter('abbb') + Counter('bcc')
        Counter({'b': 4, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] + other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __sub__(self, other):
        ''' Subtract count, but keep only results with positive counts.

        >>> Counter('abbbc') - Counter('bccd')
        Counter({'b': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] - other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __or__(self, other):
        '''Union is the maximum of value in either of the input counters.

        >>> Counter('abbb') | Counter('bcc')
        Counter({'b': 3, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        _max = max
        result = Counter()
        for elem in set(self) | set(other):
            newcount = _max(self[elem], other[elem])
            if newcount > 0:
                result[elem] = newcount
        return result

    def __and__(self, other):
        ''' Intersection is the minimum of corresponding counts.

        >>> Counter('abbb') & Counter('bcc')
        Counter({'b': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        _min = min
        result = Counter()
        # After the swap, `other` is the operand with no more keys than
        # `self`, so the loop below iterates over the smaller key set.
        if len(self) < len(other):
            self, other = other, self
        for elem in filter(self.__contains__, other):
            newcount = _min(self[elem], other[elem])
            if newcount > 0:
                result[elem] = newcount
        return result


################################################################################
### UserDict
################################################################################

class UserDict(MutableMapping):
    """Mapping wrapper that keeps its contents in the `data` dict attribute."""

    # Start by filling-out the abstract methods
    def __init__(self, dict=None, **kwargs):
        self.data = {}
        if dict is not None:
            self.update(dict)
        if kwargs:
            self.update(kwargs)
    def __len__(self): return len(self.data)
    def __getitem__(self, key):
        if key in self.data:
            return self.data[key]
        # Honour a subclass-defined __missing__ hook before giving up.
        if hasattr(self.__class__, "__missing__"):
            return self.__class__.__missing__(self, key)
        raise KeyError(key)
    def __setitem__(self, key, item): self.data[key] = item
    def __delitem__(self, key): del self.data[key]
    def __iter__(self):
        return iter(self.data)

    # Modify __contains__ to work correctly when __missing__ is present
    def __contains__(self, key):
        return key in self.data

    # Now, add the methods in dicts but not in MutableMapping
    def __repr__(self): return repr(self.data)
    def copy(self):
        if self.__class__ is UserDict:
            return UserDict(self.data.copy())
        import copy
        data = self.data
        try:
            # Temporarily swap in an empty dict so copy.copy() does not
            # share the real data; the contents are re-added via update().
            self.data = {}
            c = copy.copy(self)
        finally:
            self.data = data
        c.update(self)
        return c
    @classmethod
    def fromkeys(cls, iterable, value=None):
        d = cls()
        for key in iterable:
            d[key] = value
        return d


################################################################################
### UserList
################################################################################

class UserList(MutableSequence):
    """A more or less complete user-defined wrapper around list objects."""
    def __init__(self, initlist=None):
        self.data = []
        if initlist is not None:
            # XXX should this accept an arbitrary sequence?
            if type(initlist) == type(self.data):
                self.data[:] = initlist
            elif isinstance(initlist, UserList):
                self.data[:] = initlist.data[:]
            else:
                self.data = list(initlist)
    def __repr__(self): return repr(self.data)
    def __lt__(self, other): return self.data < self.__cast(other)
    def __le__(self, other): return self.data <= self.__cast(other)
    def __eq__(self, other): return self.data == self.__cast(other)
    def __ne__(self, other): return self.data != self.__cast(other)
    def __gt__(self, other): return self.data > self.__cast(other)
    def __ge__(self, other): return self.data >= self.__cast(other)
    def __cast(self, other):
        # Unwrap another UserList so comparisons operate on plain lists.
        return other.data if isinstance(other, UserList) else other
    def __contains__(self, item): return item in self.data
    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.data[i]
    def __setitem__(self, i, item): self.data[i] = item
    def __delitem__(self, i): del self.data[i]
    def __add__(self, other):
        if isinstance(other, UserList):
            return self.__class__(self.data + other.data)
        elif isinstance(other, type(self.data)):
            return self.__class__(self.data + other)
        return self.__class__(self.data + list(other))
    def __radd__(self, other):
        if isinstance(other, UserList):
            return self.__class__(other.data + self.data)
        elif isinstance(other, type(self.data)):
            return self.__class__(other + self.data)
        return self.__class__(list(other) + self.data)
    def __iadd__(self, other):
        if isinstance(other, UserList):
            self.data += other.data
        elif isinstance(other, type(self.data)):
            self.data += other
        else:
            self.data += list(other)
        return self
    def __mul__(self, n):
        return self.__class__(self.data*n)
    __rmul__ = __mul__
    def __imul__(self, n):
        self.data *= n
        return self
    def append(self, item): self.data.append(item)
    def insert(self, i, item): self.data.insert(i, item)
    def pop(self, i=-1): return self.data.pop(i)
    def remove(self, item): self.data.remove(item)
    def count(self, item): return self.data.count(item)
    def index(self, item, *args): return self.data.index(item, *args)
    def reverse(self): self.data.reverse()
    def sort(self, *args, **kwds): self.data.sort(*args, **kwds)
    def extend(self, other):
        if isinstance(other, UserList):
            self.data.extend(other.data)
        else:
            self.data.extend(other)


################################################################################
### UserString
################################################################################

class UserString(Sequence):
    """String wrapper that keeps its contents in the `data` str attribute."""
    def __init__(self, seq):
        if isinstance(seq, str):
            self.data = seq
        elif isinstance(seq, UserString):
            self.data = seq.data[:]
        else:
            self.data = str(seq)
    def __str__(self): return str(self.data)
    def __repr__(self): return repr(self.data)
    def __int__(self): return int(self.data)
    def __float__(self): return float(self.data)
    def __complex__(self): return complex(self.data)
    def __hash__(self): return hash(self.data)

    def __eq__(self, string):
        if isinstance(string, UserString):
            return self.data == string.data
        return self.data == string
    def __ne__(self, string):
        if isinstance(string, UserString):
            return self.data != string.data
        return self.data != string
    def __lt__(self, string):
        if isinstance(string, UserString):
            return self.data < string.data
        return self.data < string
    def __le__(self, string):
        if isinstance(string, UserString):
            return self.data <= string.data
        return self.data <= string
    def __gt__(self, string):
        if isinstance(string, UserString):
            return self.data > string.data
        return self.data > string
    def __ge__(self, string):
        if isinstance(string, UserString):
            return self.data >= string.data
        return self.data >= string

    def __contains__(self, char):
        if isinstance(char, UserString):
            char = char.data
        return char in self.data

    def __len__(self): return len(self.data)
    def __getitem__(self, index): return self.__class__(self.data[index])
    def __add__(self, other):
        if isinstance(other, UserString):
            return self.__class__(self.data + other.data)
        elif isinstance(other, str):
            return self.__class__(self.data + other)
        return self.__class__(self.data + str(other))
    def __radd__(self, other):
        if isinstance(other, str):
            return self.__class__(other + self.data)
        return self.__class__(str(other) + self.data)
    def __mul__(self, n):
        return self.__class__(self.data*n)
    __rmul__ = __mul__
    def __mod__(self, args):
        return self.__class__(self.data % args)

    # the following methods are defined in alphabetical order:
    def capitalize(self): return self.__class__(self.data.capitalize())
    def center(self, width, *args):
        return self.__class__(self.data.center(width, *args))
    def count(self, sub, start=0, end=_sys.maxsize):
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.count(sub, start, end)
    def encode(self, encoding=None, errors=None): # XXX improve this?
        # NOTE(review): str.encode() returns bytes, which the constructor
        # then passes through str(); confirm whether callers expect the raw
        # bytes instead before changing the return type.
        if encoding:
            if errors:
                return self.__class__(self.data.encode(encoding, errors))
            return self.__class__(self.data.encode(encoding))
        return self.__class__(self.data.encode())
    def endswith(self, suffix, start=0, end=_sys.maxsize):
        return self.data.endswith(suffix, start, end)
    def expandtabs(self, tabsize=8):
        return self.__class__(self.data.expandtabs(tabsize))
    def find(self, sub, start=0, end=_sys.maxsize):
        if isinstance(sub, UserString):
            sub = sub.data
        return self.data.find(sub, start, end)
    def format(self, *args, **kwds):
        return self.data.format(*args, **kwds)
    def index(self, sub, start=0, end=_sys.maxsize):
        return self.data.index(sub, start, end)
    def isalpha(self): return self.data.isalpha()
    def isalnum(self): return self.data.isalnum()
    def isdecimal(self): return self.data.isdecimal()
    def isdigit(self): return self.data.isdigit()
    def isidentifier(self): return self.data.isidentifier()
    def islower(self): return self.data.islower()
    def isnumeric(self): return self.data.isnumeric()
    def isspace(self): return self.data.isspace()
    def istitle(self): return self.data.istitle()
    def isupper(self): return self.data.isupper()
    def join(self, seq): return self.data.join(seq)
    def ljust(self, width, *args):
        return self.__class__(self.data.ljust(width, *args))
    def lower(self): return self.__class__(self.data.lower())
    def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars))
    def partition(self, sep):
        return self.data.partition(sep)
    def replace(self, old, new, maxsplit=-1):
        if isinstance(old, UserString):
            old = old.data
        if isinstance(new, UserString):
            new = new.data
        return self.__class__(self.data.replace(old, new, maxsplit))
    def rfind(self, sub, start=0, end=_sys.maxsize):
        return self.data.rfind(sub, start, end)
    def rindex(self, sub, start=0, end=_sys.maxsize):
        return self.data.rindex(sub, start, end)
    def rjust(self, width, *args):
        return self.__class__(self.data.rjust(width, *args))
    def rpartition(self, sep):
        return self.data.rpartition(sep)
    def rstrip(self, chars=None):
        return self.__class__(self.data.rstrip(chars))
    def split(self, sep=None, maxsplit=-1):
        return self.data.split(sep, maxsplit)
    def rsplit(self, sep=None, maxsplit=-1):
        return self.data.rsplit(sep, maxsplit)
    def splitlines(self, keepends=0): return self.data.splitlines(keepends)
    def startswith(self, prefix, start=0, end=_sys.maxsize):
        return self.data.startswith(prefix, start, end)
    def strip(self, chars=None): return self.__class__(self.data.strip(chars))
    def swapcase(self): return self.__class__(self.data.swapcase())
    def title(self): return self.__class__(self.data.title())
    def translate(self, *args):
        return self.__class__(self.data.translate(*args))
    def upper(self): return self.__class__(self.data.upper())
    def zfill(self, width): return self.__class__(self.data.zfill(width))


################################################################################
### Simple tests
################################################################################

if __name__ == '__main__':
    # verify that instances can be pickled
    from pickle import loads, dumps
    Point = namedtuple('Point', 'x, y', True)
    p = Point(x=10, y=20)
    assert p == loads(dumps(p))

    # test and demonstrate ability to override methods
    class Point(namedtuple('Point', 'x y')):
        __slots__ = ()
        @property
        def hypot(self):
            return (self.x ** 2 + self.y ** 2) ** 0.5
        def __str__(self):
            return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, self.hypot)

    for p in Point(3, 4), Point(14, 5/7.):
        print (p)

    class Point(namedtuple('Point', 'x y')):
        'Point class with optimized _make() and _replace() without error-checking'
        __slots__ = ()
        _make = classmethod(tuple.__new__)
        def _replace(self, _map=map, **kwds):
            return self._make(_map(kwds.get, ('x', 'y'), self))

    print(Point(11, 22)._replace(x=100))

    Point3D = namedtuple('Point3D', Point._fields + ('z',))
    print(Point3D.__doc__)

    import doctest
    TestResults = namedtuple('TestResults', 'failed attempted')
    print(TestResults(*doctest.testmod()))