From 3bbe3b7c822091caac90c00ee937848bc4de80eb Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 5 Oct 2023 00:47:41 +0800 Subject: gh-110222: Add support of PyStructSequence in copy.replace() (GH-110223) --- Lib/test/test_structseq.py | 78 ++++++++++++++++++++++ .../2023-10-02-15-07-28.gh-issue-110222.zl_oHh.rst | 2 + Objects/structseq.c | 76 ++++++++++++++++++++- 3 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 Misc/NEWS.d/next/Library/2023-10-02-15-07-28.gh-issue-110222.zl_oHh.rst diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index 2ef1316..6aec63e 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -264,6 +264,84 @@ class StructSeqTest(unittest.TestCase): self.assertEqual(os.stat_result.n_unnamed_fields, 3) self.assertEqual(os.stat_result.__match_args__, expected_args) + def test_copy_replace_all_fields_visible(self): + assert os.times_result.n_unnamed_fields == 0 + assert os.times_result.n_sequence_fields == os.times_result.n_fields + + t = os.times() + + # visible fields + self.assertEqual(copy.replace(t), t) + self.assertIsInstance(copy.replace(t), os.times_result) + self.assertEqual(copy.replace(t, user=1.5), (1.5, *t[1:])) + self.assertEqual(copy.replace(t, system=2.5), (t[0], 2.5, *t[2:])) + self.assertEqual(copy.replace(t, user=1.5, system=2.5), (1.5, 2.5, *t[2:])) + + # unknown fields + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, error=-1) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, user=1, error=-1) + + def test_copy_replace_with_invisible_fields(self): + assert time.struct_time.n_unnamed_fields == 0 + assert time.struct_time.n_sequence_fields < time.struct_time.n_fields + + t = time.gmtime(0) + + # visible fields + t2 = copy.replace(t) + self.assertEqual(t2, (1970, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertIsInstance(t2, time.struct_time) + t3 = copy.replace(t, tm_year=2000) + self.assertEqual(t3, (2000, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t3.tm_year, 2000) + t4 = copy.replace(t, tm_mon=2) + self.assertEqual(t4, (1970, 2, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t4.tm_mon, 2) + t5 = copy.replace(t, tm_year=2000, tm_mon=2) + self.assertEqual(t5, (2000, 2, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t5.tm_year, 2000) + self.assertEqual(t5.tm_mon, 2) + + # named invisible fields + self.assertTrue(hasattr(t, 'tm_zone'), f"{t} has no attribute 'tm_zone'") + with self.assertRaisesRegex(AttributeError, 'readonly attribute'): + t.tm_zone = 'some other zone' + self.assertEqual(t2.tm_zone, t.tm_zone) + self.assertEqual(t3.tm_zone, t.tm_zone) + self.assertEqual(t4.tm_zone, t.tm_zone) + t6 = copy.replace(t, tm_zone='some other zone') + self.assertEqual(t, t6) + self.assertEqual(t6.tm_zone, 'some other zone') + t7 = copy.replace(t, tm_year=2000, tm_zone='some other zone') + self.assertEqual(t7, (2000, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t7.tm_year, 2000) + self.assertEqual(t7.tm_zone, 'some other zone') + + # unknown fields + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, error=2) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, tm_year=2000, error=2) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, tm_zone='some other zone', error=2) + + def test_copy_replace_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + r = os.stat_result(range(os.stat_result.n_sequence_fields)) + + error_message = re.escape('__replace__() is not supported') + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, st_mode=1) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, error=2) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, st_mode=1, error=2) + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2023-10-02-15-07-28.gh-issue-110222.zl_oHh.rst b/Misc/NEWS.d/next/Library/2023-10-02-15-07-28.gh-issue-110222.zl_oHh.rst new file mode 100644 index 0000000..fd2ecdf --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-10-02-15-07-28.gh-issue-110222.zl_oHh.rst @@ -0,0 +1,2 @@ +Add support of struct sequence objects in :func:`copy.replace`. +Patched by Xuehai Pan. diff --git a/Objects/structseq.c b/Objects/structseq.c index 2c98288..e4a4b45 100644 --- a/Objects/structseq.c +++ b/Objects/structseq.c @@ -8,6 +8,7 @@ */ #include "Python.h" +#include "pycore_dict.h" // _PyDict_Pop() #include "pycore_tuple.h" // _PyTuple_FromArray() #include "pycore_object.h" // _PyObject_GC_TRACK() @@ -380,9 +381,82 @@ error: return NULL; } + +static PyObject * +structseq_replace(PyStructSequence *self, PyObject *args, PyObject *kwargs) +{ + PyStructSequence *result = NULL; + Py_ssize_t n_fields, n_unnamed_fields, i; + + if (!_PyArg_NoPositional("__replace__", args)) { + return NULL; + } + + n_fields = REAL_SIZE(self); + if (n_fields < 0) { + return NULL; + } + n_unnamed_fields = UNNAMED_FIELDS(self); + if (n_unnamed_fields < 0) { + return NULL; + } + if (n_unnamed_fields > 0) { + PyErr_Format(PyExc_TypeError, + "__replace__() is not supported for %.500s " + "because it has unnamed field(s)", + Py_TYPE(self)->tp_name); + return NULL; + } + + result = (PyStructSequence *) PyStructSequence_New(Py_TYPE(self)); + if (!result) { + return NULL; + } + + if (kwargs != NULL) { + // We do not support types with unnamed fields, so we can iterate over + // i >= n_visible_fields case without slicing with (i - n_unnamed_fields). + for (i = 0; i < n_fields; ++i) { + PyObject *key = PyUnicode_FromString(Py_TYPE(self)->tp_members[i].name); + if (!key) { + goto error; + } + PyObject *ob = _PyDict_Pop(kwargs, key, self->ob_item[i]); + Py_DECREF(key); + if (!ob) { + goto error; + } + result->ob_item[i] = ob; + } + // Check if there are any unexpected fields. + if (PyDict_GET_SIZE(kwargs) > 0) { + PyObject *names = PyDict_Keys(kwargs); + if (names) { + PyErr_Format(PyExc_TypeError, "Got unexpected field name(s): %R", names); + Py_DECREF(names); + } + goto error; + } + } + else + { + // Just create a copy of the original. + for (i = 0; i < n_fields; ++i) { + result->ob_item[i] = Py_NewRef(self->ob_item[i]); + } + } + + return (PyObject *)result; + +error: + Py_DECREF(result); + return NULL; +} + static PyMethodDef structseq_methods[] = { {"__reduce__", (PyCFunction)structseq_reduce, METH_NOARGS, NULL}, - {NULL, NULL} + {"__replace__", _PyCFunction_CAST(structseq_replace), METH_VARARGS | METH_KEYWORDS, NULL}, + {NULL, NULL} // sentinel }; static Py_ssize_t -- cgit v0.12