From 1f5fe9962f768c8bfd4ed06a22532d31d3424dc9 Mon Sep 17 00:00:00 2001
From: "Miss Islington (bot)" <31488909+miss-islington@users.noreply.github.com>
Date: Fri, 11 Feb 2022 12:44:17 -0800
Subject: bpo-46615: Don't crash when set operations mutate the sets (GH-31120)

Ensure strong references are acquired whenever using `set_next()`. Added randomized test cases for `__eq__` methods that sometimes mutate sets when called.
(cherry picked from commit 4a66615ba736f84eadf9456bfd5d32a94cccf117)

Co-authored-by: Dennis Sweeney <36520290+sweeneyde@users.noreply.github.com>
---
 Lib/test/test_set.py                               | 190 +++++++++++++++++++++
 .../2022-02-04-04-33-18.bpo-46615.puArY9.rst       |   1 +
 Objects/setobject.c                                |  47 +++++-
 3 files changed, 230 insertions(+), 8 deletions(-)
 create mode 100644 Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst

diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 29bb39d..824eddb 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -1766,6 +1766,196 @@ class TestWeirdBugs(unittest.TestCase):
         s = {0}
         s.update(other)
 
+
+class TestOperationsMutating:
+    """Regression test for bpo-46615"""
+
+    # Subclasses set these to the callables that build each operand.
+    constructor1 = None
+    constructor2 = None
+
+    def make_sets_of_bad_objects(self):
+        """Return two containers of Bad objects whose __eq__ may clear them."""
+        class Bad:
+            def __eq__(self, other):
+                if not enabled:
+                    return False
+                if randrange(20) == 0:
+                    set1.clear()
+                if randrange(20) == 0:
+                    set2.clear()
+                return bool(randrange(2))
+            def __hash__(self):
+                # Only two possible hashes, so collisions force __eq__ calls.
+                return randrange(2)
+        # Don't behave poorly during construction.
+        enabled = False
+        set1 = self.constructor1(Bad() for _ in range(randrange(50)))
+        set2 = self.constructor2(Bad() for _ in range(randrange(50)))
+        # Now start behaving poorly
+        enabled = True
+        return set1, set2
+
+    def check_set_op_does_not_crash(self, function):
+        """Run function(set1, set2) on 100 fresh pairs of bad containers."""
+        for _ in range(100):
+            set1, set2 = self.make_sets_of_bad_objects()
+            try:
+                function(set1, set2)
+            except RuntimeError as e:
+                # Just make sure we don't crash here.
+                self.assertIn("changed size during iteration", str(e))
+
+
+class TestBinaryOpsMutating(TestOperationsMutating):
+
+    def test_eq_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a == b)
+
+    def test_ne_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a != b)
+
+    def test_lt_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a < b)
+
+    def test_le_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a <= b)
+
+    def test_gt_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a > b)
+
+    def test_ge_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a >= b)
+
+    def test_and_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a & b)
+
+    def test_or_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a | b)
+
+    def test_sub_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a - b)
+
+    def test_xor_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a ^ b)
+
+    def test_iand_with_mutation(self):
+        def f(a, b):
+            a &= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_ior_with_mutation(self):
+        def f(a, b):
+            a |= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_isub_with_mutation(self):
+        def f(a, b):
+            a -= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_ixor_with_mutation(self):
+        def f(a, b):
+            a ^= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_iteration_with_mutation(self):
+        def f1(a, b):
+            for x in a:
+                pass
+            for y in b:
+                pass
+        def f2(a, b):
+            for y in b:
+                pass
+            for x in a:
+                pass
+        def f3(a, b):
+            for x, y in zip(a, b):
+                pass
+        self.check_set_op_does_not_crash(f1)
+        self.check_set_op_does_not_crash(f2)
+        self.check_set_op_does_not_crash(f3)
+
+
+class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = set
+
+class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = SetSubclass
+
+class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = SetSubclass
+
+class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = set
+
+
+class TestMethodsMutating(TestOperationsMutating):
+
+    def test_issubset_with_mutation(self):
+        self.check_set_op_does_not_crash(set.issubset)
+
+    def test_issuperset_with_mutation(self):
+        self.check_set_op_does_not_crash(set.issuperset)
+
+    def test_intersection_with_mutation(self):
+        self.check_set_op_does_not_crash(set.intersection)
+
+    def test_union_with_mutation(self):
+        self.check_set_op_does_not_crash(set.union)
+
+    def test_difference_with_mutation(self):
+        self.check_set_op_does_not_crash(set.difference)
+
+    def test_symmetric_difference_with_mutation(self):
+        self.check_set_op_does_not_crash(set.symmetric_difference)
+
+    def test_isdisjoint_with_mutation(self):
+        self.check_set_op_does_not_crash(set.isdisjoint)
+
+    def test_difference_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.difference_update)
+
+    def test_intersection_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.intersection_update)
+
+    def test_symmetric_difference_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.symmetric_difference_update)
+
+    def test_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.update)
+
+
+class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = set
+
+class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = SetSubclass
+
+class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = SetSubclass
+
+class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = set
+
+class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = dict.fromkeys
+
+class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = list
+
+
 # Application tests (based on David Eppstein's graph recipes ====================================
 
 def powerset(U):
diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst b/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst
new file mode 100644
index 0000000..6dee92a
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst
@@ -0,0 +1 @@
+When iterating over sets internally in ``setobject.c``, acquire strong references to the resulting items from the set. This prevents crashes in corner-cases of various set operations where the set gets mutated.
diff --git a/Objects/setobject.c b/Objects/setobject.c
index 6524963..e8ba32e 100644
--- a/Objects/setobject.c
+++ b/Objects/setobject.c
@@ -1204,17 +1204,21 @@ set_intersection(PySetObject *so, PyObject *other)
         while (set_next((PySetObject *)other, &pos, &entry)) {
             key = entry->key;
             hash = entry->hash;
+            Py_INCREF(key);
             rv = set_contains_entry(so, key, hash);
             if (rv < 0) {
                 Py_DECREF(result);
+                Py_DECREF(key);
                 return NULL;
             }
             if (rv) {
                 if (set_add_entry(result, key, hash)) {
                     Py_DECREF(result);
+                    Py_DECREF(key);
                     return NULL;
                 }
             }
+            Py_DECREF(key);
         }
         return (PyObject *)result;
     }
@@ -1354,11 +1358,16 @@ set_isdisjoint(PySetObject *so, PyObject *other)
             other = tmp;
         }
         while (set_next((PySetObject *)other, &pos, &entry)) {
-            rv = set_contains_entry(so, entry->key, entry->hash);
-            if (rv < 0)
+            PyObject *key = entry->key;
+            Py_INCREF(key);
+            rv = set_contains_entry(so, key, entry->hash);
+            Py_DECREF(key);
+            if (rv < 0) {
                 return NULL;
-            if (rv)
+            }
+            if (rv) {
                 Py_RETURN_FALSE;
+            }
         }
         Py_RETURN_TRUE;
     }
@@ -1417,11 +1426,16 @@ set_difference_update_internal(PySetObject *so, PyObject *other)
             Py_INCREF(other);
         }
 
-        while (set_next((PySetObject *)other, &pos, &entry))
-            if (set_discard_entry(so, entry->key, entry->hash) < 0) {
+        while (set_next((PySetObject *)other, &pos, &entry)) {
+            PyObject *key = entry->key;
+            Py_INCREF(key);
+            if (set_discard_entry(so, key, entry->hash) < 0) {
                 Py_DECREF(other);
+                Py_DECREF(key);
                 return -1;
             }
+            Py_DECREF(key);
+        }
         Py_DECREF(other);
     }
     else {
@@ -1512,17 +1526,21 @@ set_difference(PySetObject *so, PyObject *other)
         while (set_next(so, &pos, &entry)) {
             key = entry->key;
             hash = entry->hash;
+            Py_INCREF(key);
             rv = _PyDict_Contains_KnownHash(other, key, hash);
             if (rv < 0) {
                 Py_DECREF(result);
+                Py_DECREF(key);
                 return NULL;
             }
             if (!rv) {
                 if (set_add_entry((PySetObject *)result, key, hash)) {
                     Py_DECREF(result);
+                    Py_DECREF(key);
                     return NULL;
                 }
             }
+            Py_DECREF(key);
         }
         return result;
     }
@@ -1531,17 +1549,21 @@ set_difference(PySetObject *so, PyObject *other)
     while (set_next(so, &pos, &entry)) {
         key = entry->key;
         hash = entry->hash;
+        Py_INCREF(key);
         rv = set_contains_entry((PySetObject *)other, key, hash);
         if (rv < 0) {
             Py_DECREF(result);
+            Py_DECREF(key);
             return NULL;
         }
         if (!rv) {
             if (set_add_entry((PySetObject *)result, key, hash)) {
                 Py_DECREF(result);
+                Py_DECREF(key);
                 return NULL;
             }
         }
+        Py_DECREF(key);
     }
     return result;
 }
@@ -1638,17 +1660,21 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
     while (set_next(otherset, &pos, &entry)) {
         key = entry->key;
         hash = entry->hash;
+        Py_INCREF(key);
         rv = set_discard_entry(so, key, hash);
         if (rv < 0) {
             Py_DECREF(otherset);
+            Py_DECREF(key);
             return NULL;
         }
         if (rv == DISCARD_NOTFOUND) {
             if (set_add_entry(so, key, hash)) {
                 Py_DECREF(otherset);
+                Py_DECREF(key);
                 return NULL;
             }
         }
+        Py_DECREF(key);
     }
     Py_DECREF(otherset);
     Py_RETURN_NONE;
@@ -1723,11 +1749,16 @@ set_issubset(PySetObject *so, PyObject *other)
         Py_RETURN_FALSE;
 
     while (set_next(so, &pos, &entry)) {
-        rv = set_contains_entry((PySetObject *)other, entry->key, entry->hash);
-        if (rv < 0)
+        PyObject *key = entry->key;
+        Py_INCREF(key);
+        rv = set_contains_entry((PySetObject *)other, key, entry->hash);
+        Py_DECREF(key);
+        if (rv < 0) {
             return NULL;
-        if (!rv)
+        }
+        if (!rv) {
             Py_RETURN_FALSE;
+        }
     }
     Py_RETURN_TRUE;
 }
-- 
cgit v0.12