diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/operator.py | 46 | ||||
-rw-r--r-- | Lib/test/test_operator.py | 109 |
2 files changed, 153 insertions, 2 deletions
diff --git a/Lib/operator.py b/Lib/operator.py index 856036d..0db51c1 100644 --- a/Lib/operator.py +++ b/Lib/operator.py @@ -231,10 +231,13 @@ class attrgetter: After h = attrgetter('name.first', 'name.last'), the call h(r) returns (r.name.first, r.name.last). """ + __slots__ = ('_attrs', '_call') + def __init__(self, attr, *attrs): if not attrs: if not isinstance(attr, str): raise TypeError('attribute name must be a string') + self._attrs = (attr,) names = attr.split('.') def func(obj): for name in names: @@ -242,7 +245,8 @@ class attrgetter: return obj self._call = func else: - getters = tuple(map(attrgetter, (attr,) + attrs)) + self._attrs = (attr,) + attrs + getters = tuple(map(attrgetter, self._attrs)) def func(obj): return tuple(getter(obj) for getter in getters) self._call = func @@ -250,19 +254,30 @@ class attrgetter: def __call__(self, obj): return self._call(obj) + def __repr__(self): + return '%s.%s(%s)' % (self.__class__.__module__, + self.__class__.__qualname__, + ', '.join(map(repr, self._attrs))) + + def __reduce__(self): + return self.__class__, self._attrs + class itemgetter: """ Return a callable object that fetches the given item(s) from its operand. After f = itemgetter(2), the call f(r) returns r[2]. After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3]) """ + __slots__ = ('_items', '_call') + def __init__(self, item, *items): if not items: + self._items = (item,) def func(obj): return obj[item] self._call = func else: - items = (item,) + items + self._items = items = (item,) + items def func(obj): return tuple(obj[i] for i in items) self._call = func @@ -270,6 +285,14 @@ class itemgetter: def __call__(self, obj): return self._call(obj) + def __repr__(self): + return '%s.%s(%s)' % (self.__class__.__module__, + self.__class__.__name__, + ', '.join(map(repr, self._items))) + + def __reduce__(self): + return self.__class__, self._items + class methodcaller: """ Return a callable object that calls the given method on its operand. @@ -277,6 +300,7 @@ class methodcaller: After g = methodcaller('name', 'date', foo=1), the call g(r) returns r.name('date', foo=1). """ + __slots__ = ('_name', '_args', '_kwargs') def __init__(*args, **kwargs): if len(args) < 2: @@ -284,12 +308,30 @@ class methodcaller: raise TypeError(msg) self = args[0] self._name = args[1] + if not isinstance(self._name, str): + raise TypeError('method name must be a string') self._args = args[2:] self._kwargs = kwargs def __call__(self, obj): return getattr(obj, self._name)(*self._args, **self._kwargs) + def __repr__(self): + args = [repr(self._name)] + args.extend(map(repr, self._args)) + args.extend('%s=%r' % (k, v) for k, v in self._kwargs.items()) + return '%s.%s(%s)' % (self.__class__.__module__, + self.__class__.__name__, + ', '.join(args)) + + def __reduce__(self): + if not self._kwargs: + return self.__class__, (self._name,) + self._args + else: + from functools import partial + return partial(self.__class__, self._name, **self._kwargs), self._args + + # In-place Operations *********************************************************# def iadd(a, b): diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py index 1bd0391..ef9cf3e 100644 --- a/Lib/test/test_operator.py +++ b/Lib/test/test_operator.py @@ -1,4 +1,6 @@ import unittest +import pickle +import sys from test import support @@ -35,6 +37,9 @@ class Seq2(object): class OperatorTestCase: + def setUp(self): + sys.modules['operator'] = self.module + def test_lt(self): operator = self.module self.assertRaises(TypeError, operator.lt) @@ -396,6 +401,7 @@ class OperatorTestCase: def test_methodcaller(self): operator = self.module self.assertRaises(TypeError, operator.methodcaller) + self.assertRaises(TypeError, operator.methodcaller, 12) class A: def foo(self, *args, **kwds): return args[0] + args[1] @@ -491,5 +497,108 @@ class PyOperatorTestCase(OperatorTestCase, unittest.TestCase): class COperatorTestCase(OperatorTestCase, unittest.TestCase): module = c_operator + +class OperatorPickleTestCase: + def copy(self, obj, proto): + with support.swap_item(sys.modules, 'operator', self.module): + pickled = pickle.dumps(obj, proto) + with support.swap_item(sys.modules, 'operator', self.module2): + return pickle.loads(pickled) + + def test_attrgetter(self): + attrgetter = self.module.attrgetter + attrgetter = self.module.attrgetter + class A: + pass + a = A() + a.x = 'X' + a.y = 'Y' + a.z = 'Z' + a.t = A() + a.t.u = A() + a.t.u.v = 'V' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + f = attrgetter('x') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # multiple gets + f = attrgetter('x', 'y', 'z') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # recursive gets + f = attrgetter('t.u.v') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + + def test_itemgetter(self): + itemgetter = self.module.itemgetter + a = 'ABCDE' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + f = itemgetter(2) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # multiple gets + f = itemgetter(2, 0, 4) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + + def test_methodcaller(self): + methodcaller = self.module.methodcaller + class A: + def foo(self, *args, **kwds): + return args[0] + args[1] + def bar(self, f=42): + return f + def baz(*args, **kwds): + return kwds['name'], kwds['self'] + a = A() + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + f = methodcaller('bar') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # positional args + f = methodcaller('foo', 1, 2) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # keyword args + f = methodcaller('bar', f=5) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + f = methodcaller('baz', self='eggs', name='spam') + f2 = self.copy(f, proto) + # Can't test repr consistently with multiple keyword args + self.assertEqual(f2(a), f(a)) + +class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = py_operator + module2 = py_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = py_operator + module2 = c_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = c_operator + module2 = py_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = c_operator + module2 = c_operator + + if __name__ == "__main__": unittest.main() |