From 166958b5df50fca05cb24be0152737edf575dbb9 Mon Sep 17 00:00:00 2001
From: Raymond Hettinger <python@rcn.com>
Date: Mon, 1 Dec 2003 13:18:39 +0000
Subject: As discussed on python-dev, added two extractor functions to the
 operator module.

---
 Doc/lib/liboperator.tex   |  33 +++++++
 Lib/test/test_operator.py |  39 ++++++++
 Misc/NEWS                 |   5 +
 Modules/operator.c        | 227 +++++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 302 insertions(+), 2 deletions(-)

diff --git a/Doc/lib/liboperator.tex b/Doc/lib/liboperator.tex
index e3aae60..ab68d0e 100644
--- a/Doc/lib/liboperator.tex
+++ b/Doc/lib/liboperator.tex
@@ -300,6 +300,39 @@ Example: Build a dictionary that maps the ordinals from \code{0} to
 \end{verbatim}
 
 
+The \module{operator} module also defines tools for generalized attribute
+and item lookups.  These are useful for making fast field extractors
+as arguments for \function{map()}, \method{list.sort()},
+\method{itertools.groupby()}, or other functions that expect a
+function argument.
+
+\begin{funcdesc}{attrgetter}{attr}
+Return a callable object that fetches \var{attr} from its operand.
+After, \samp{f=attrgetter('name')}, the call \samp{f(b)} returns
+\samp{b.name}.
+\versionadded{2.4}
+\end{funcdesc}
+    
+\begin{funcdesc}{itemgetter}{item}
+Return a callable object that fetches \var{item} from its operand.
+After, \samp{f=itemgetter(2)}, the call \samp{f(b)} returns
+\samp{b[2]}.
+\versionadded{2.4}
+\end{funcdesc}
+
+Examples:
+                
+\begin{verbatim}
+>>> from operator import *
+>>> inventory = [('apple', 3), ('banana', 2), ('pear', 5), ('orange', 1)]
+>>> getcount = itemgetter(1)
+>>> map(getcount, inventory)
+[3, 2, 5, 1]
+>>> list.sorted(inventory, key=getcount)
+[('orange', 1), ('banana', 2), ('apple', 3), ('pear', 5)]
+\end{verbatim}
+                
+
 \subsection{Mapping Operators to Functions \label{operator-map}}
 
 This table shows how abstract operations correspond to operator
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
index 422a3cb..e3a67f0 100644
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -227,6 +227,45 @@ class OperatorTestCase(unittest.TestCase):
         self.failIf(operator.is_not(a, b))
         self.failUnless(operator.is_not(a,c))
 
+    def test_attrgetter(self):
+        class A:
+            pass
+        a = A()
+        a.name = 'arthur'
+        f = operator.attrgetter('name')
+        self.assertEqual(f(a), 'arthur')
+        f = operator.attrgetter('rank')
+        self.assertRaises(AttributeError, f, a)
+        f = operator.attrgetter(2)
+        self.assertRaises(TypeError, f, a)
+        self.assertRaises(TypeError, operator.attrgetter)
+        self.assertRaises(TypeError, operator.attrgetter, 1, 2)
+
+    def test_itemgetter(self):
+        a = 'ABCDE'
+        f = operator.itemgetter(2)
+        self.assertEqual(f(a), 'C')
+        f = operator.itemgetter(10)
+        self.assertRaises(IndexError, f, a)
+
+        f = operator.itemgetter('name')
+        self.assertRaises(TypeError, f, a)
+        self.assertRaises(TypeError, operator.itemgetter)
+        self.assertRaises(TypeError, operator.itemgetter, 1, 2)
+
+        d = dict(key='val')
+        f = operator.itemgetter('key')
+        self.assertEqual(f(d), 'val')
+        f = operator.itemgetter('nonkey')
+        self.assertRaises(KeyError, f, d)
+
+        # example used in the docs
+        inventory = [('apple', 3), ('banana', 2), ('pear', 5), ('orange', 1)]
+        getcount = operator.itemgetter(1)
+        self.assertEqual(map(getcount, inventory), [3, 2, 5, 1])
+        self.assertEqual(list.sorted(inventory, key=getcount),
+            [('orange', 1), ('banana', 2), ('apple', 3), ('pear', 5)])
+
 def test_main():
     test_support.run_unittest(OperatorTestCase)
 
diff --git a/Misc/NEWS b/Misc/NEWS
index ce9d779..8f1c1a1 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -104,6 +104,11 @@ Core and builtins
 Extension modules
 -----------------
 
+- The operator module has two new functions, attrgetter() and
+  itemgetter() which are useful for creating fast data extractor
+  functions for map(), list.sort(), itertools.groupby(), and
+  other functions that expect a function argument.
+
 - socket.SHUT_{RD,WR,RDWR} was added.
 
 - os.getsid was added.
diff --git a/Modules/operator.c b/Modules/operator.c
index 7638fb8..d8e2a54 100644
--- a/Modules/operator.c
+++ b/Modules/operator.c
@@ -252,13 +252,236 @@ spam2(ge,__ge__, "ge(a, b) -- Same as a>=b.")
 
 };
 
+/* itemgetter object **********************************************************/
 
+typedef struct {
+	PyObject_HEAD
+	PyObject *item;
+} itemgetterobject;
+
+static PyTypeObject itemgetter_type;
+
+static PyObject *
+itemgetter_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+	itemgetterobject *ig;
+	PyObject *item;
+
+	if (!PyArg_UnpackTuple(args, "itemgetter", 1, 1, &item))
+		return NULL;
+
+	/* create itemgetterobject structure */
+	ig = PyObject_GC_New(itemgetterobject, &itemgetter_type);
+	if (ig == NULL) 
+		return NULL;	
+	
+	Py_INCREF(item);
+	ig->item = item;
+
+	PyObject_GC_Track(ig);
+	return (PyObject *)ig;
+}
+
+static void
+itemgetter_dealloc(itemgetterobject *ig)
+{
+	PyObject_GC_UnTrack(ig);
+	Py_XDECREF(ig->item);
+	PyObject_GC_Del(ig);
+}
+
+static int
+itemgetter_traverse(itemgetterobject *ig, visitproc visit, void *arg)
+{
+	if (ig->item)
+		return visit(ig->item, arg);
+	return 0;
+}
+
+static PyObject *
+itemgetter_call(itemgetterobject *ig, PyObject *args, PyObject *kw)
+{
+	PyObject * obj;
+
+	if (!PyArg_UnpackTuple(args, "itemgetter", 1, 1, &obj))
+		return NULL;
+	return PyObject_GetItem(obj, ig->item);
+}
+
+PyDoc_STRVAR(itemgetter_doc,
+"itemgetter(item) --> itemgetter object\n\
+\n\
+Return a callable object that fetches the given item from its operand.\n\
+After, f=itemgetter(2), the call f(b) returns b[2].");
+
+static PyTypeObject itemgetter_type = {
+	PyObject_HEAD_INIT(NULL)
+	0,				/* ob_size */
+	"itertools.itemgetter",		/* tp_name */
+	sizeof(itemgetterobject),	/* tp_basicsize */
+	0,				/* tp_itemsize */
+	/* methods */
+	(destructor)itemgetter_dealloc,	/* tp_dealloc */
+	0,				/* tp_print */
+	0,				/* tp_getattr */
+	0,				/* tp_setattr */
+	0,				/* tp_compare */
+	0,				/* tp_repr */
+	0,				/* tp_as_number */
+	0,				/* tp_as_sequence */
+	0,				/* tp_as_mapping */
+	0,				/* tp_hash */
+	(ternaryfunc)itemgetter_call,	/* tp_call */
+	0,				/* tp_str */
+	PyObject_GenericGetAttr,	/* tp_getattro */
+	0,				/* tp_setattro */
+	0,				/* tp_as_buffer */
+	Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,	/* tp_flags */
+	itemgetter_doc,			/* tp_doc */
+	(traverseproc)itemgetter_traverse,	/* tp_traverse */
+	0,				/* tp_clear */
+	0,				/* tp_richcompare */
+	0,				/* tp_weaklistoffset */
+	0,				/* tp_iter */
+	0,				/* tp_iternext */
+	0,				/* tp_methods */
+	0,				/* tp_members */
+	0,				/* tp_getset */
+	0,				/* tp_base */
+	0,				/* tp_dict */
+	0,				/* tp_descr_get */
+	0,				/* tp_descr_set */
+	0,				/* tp_dictoffset */
+	0,				/* tp_init */
+	0,				/* tp_alloc */
+	itemgetter_new,			/* tp_new */
+	0,				/* tp_free */
+};
+
+
+/* attrgetter object **********************************************************/
+
+typedef struct {
+	PyObject_HEAD
+	PyObject *attr;
+} attrgetterobject;
+
+static PyTypeObject attrgetter_type;
+
+static PyObject *
+attrgetter_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+	attrgetterobject *ag;
+	PyObject *attr;
+
+	if (!PyArg_UnpackTuple(args, "attrgetter", 1, 1, &attr))
+		return NULL;
+
+	/* create attrgetterobject structure */
+	ag = PyObject_GC_New(attrgetterobject, &attrgetter_type);
+	if (ag == NULL) 
+		return NULL;	
+	
+	Py_INCREF(attr);
+	ag->attr = attr;
+
+	PyObject_GC_Track(ag);
+	return (PyObject *)ag;
+}
+
+static void
+attrgetter_dealloc(attrgetterobject *ag)
+{
+	PyObject_GC_UnTrack(ag);
+	Py_XDECREF(ag->attr);
+	PyObject_GC_Del(ag);
+}
+
+static int
+attrgetter_traverse(attrgetterobject *ag, visitproc visit, void *arg)
+{
+	if (ag->attr)
+		return visit(ag->attr, arg);
+	return 0;
+}
+
+static PyObject *
+attrgetter_call(attrgetterobject *ag, PyObject *args, PyObject *kw)
+{
+	PyObject * obj;
+
+	if (!PyArg_UnpackTuple(args, "attrgetter", 1, 1, &obj))
+		return NULL;
+	return PyObject_GetAttr(obj, ag->attr);
+}
+
+PyDoc_STRVAR(attrgetter_doc,
+"attrgetter(attr) --> attrgetter object\n\
+\n\
+Return a callable object that fetches the given attribute from its operand.\n\
+After, f=attrgetter('name'), the call f(b) returns b.name.");
+
+static PyTypeObject attrgetter_type = {
+	PyObject_HEAD_INIT(NULL)
+	0,				/* ob_size */
+	"itertools.attrgetter",		/* tp_name */
+	sizeof(attrgetterobject),	/* tp_basicsize */
+	0,				/* tp_itemsize */
+	/* methods */
+	(destructor)attrgetter_dealloc,	/* tp_dealloc */
+	0,				/* tp_print */
+	0,				/* tp_getattr */
+	0,				/* tp_setattr */
+	0,				/* tp_compare */
+	0,				/* tp_repr */
+	0,				/* tp_as_number */
+	0,				/* tp_as_sequence */
+	0,				/* tp_as_mapping */
+	0,				/* tp_hash */
+	(ternaryfunc)attrgetter_call,	/* tp_call */
+	0,				/* tp_str */
+	PyObject_GenericGetAttr,	/* tp_getattro */
+	0,				/* tp_setattro */
+	0,				/* tp_as_buffer */
+	Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,	/* tp_flags */
+	attrgetter_doc,			/* tp_doc */
+	(traverseproc)attrgetter_traverse,	/* tp_traverse */
+	0,				/* tp_clear */
+	0,				/* tp_richcompare */
+	0,				/* tp_weaklistoffset */
+	0,				/* tp_iter */
+	0,				/* tp_iternext */
+	0,				/* tp_methods */
+	0,				/* tp_members */
+	0,				/* tp_getset */
+	0,				/* tp_base */
+	0,				/* tp_dict */
+	0,				/* tp_descr_get */
+	0,				/* tp_descr_set */
+	0,				/* tp_dictoffset */
+	0,				/* tp_init */
+	0,				/* tp_alloc */
+	attrgetter_new,			/* tp_new */
+	0,				/* tp_free */
+};
 /* Initialization function for the module (*must* be called initoperator) */
 
 PyMODINIT_FUNC
 initoperator(void)
 {
-        /* Create the module and add the functions */
-        Py_InitModule4("operator", operator_methods, operator_doc,
+	PyObject *m;
+        
+	/* Create the module and add the functions */
+        m = Py_InitModule4("operator", operator_methods, operator_doc,
 		       (PyObject*)NULL, PYTHON_API_VERSION);
+
+	if (PyType_Ready(&itemgetter_type) < 0)
+		return;
+	Py_INCREF(&itemgetter_type);
+	PyModule_AddObject(m, "itemgetter", (PyObject *)&itemgetter_type);
+
+	if (PyType_Ready(&attrgetter_type) < 0)
+		return;
+	Py_INCREF(&attrgetter_type);
+	PyModule_AddObject(m, "attrgetter", (PyObject *)&attrgetter_type);
 }
-- 
cgit v0.12