diff options
Diffstat (limited to 'Lib/test/audit-tests.py')
-rw-r--r-- | Lib/test/audit-tests.py | 269 |
1 files changed, 269 insertions, 0 deletions
diff --git a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py new file mode 100644 index 0000000..7a7725f --- /dev/null +++ b/Lib/test/audit-tests.py @@ -0,0 +1,269 @@ +"""This script contains the actual auditing tests. + +It should not be imported directly, but should be run by the test_audit +module with arguments identifying each test. + +""" + +import contextlib +import sys + + +class TestHook: + """Used in standard hook tests to collect any logged events. + + Should be used in a with block to ensure that it has no impact + after the test completes. + """ + + def __init__(self, raise_on_events=None, exc_type=RuntimeError): + self.raise_on_events = raise_on_events or () + self.exc_type = exc_type + self.seen = [] + self.closed = False + + def __enter__(self, *a): + sys.addaudithook(self) + return self + + def __exit__(self, *a): + self.close() + + def close(self): + self.closed = True + + @property + def seen_events(self): + return [i[0] for i in self.seen] + + def __call__(self, event, args): + if self.closed: + return + self.seen.append((event, args)) + if event in self.raise_on_events: + raise self.exc_type("saw event " + event) + + +class TestFinalizeHook: + """Used in the test_finalize_hooks function to ensure that hooks + are correctly cleaned up, that they are notified about the cleanup, + and are unable to prevent it. + """ + + def __init__(self): + print("Created", id(self), file=sys.stdout, flush=True) + + def __call__(self, event, args): + # Avoid recursion when we call id() below + if event == "builtins.id": + return + + print(event, id(self), file=sys.stdout, flush=True) + + if event == "cpython._PySys_ClearAuditHooks": + raise RuntimeError("Should be ignored") + elif event == "cpython.PyInterpreterState_Clear": + raise RuntimeError("Should be ignored") + + +# Simple helpers, since we are not in unittest here +def assertEqual(x, y): + if x != y: + raise AssertionError(f"{x!r} should equal {y!r}") + + +def assertIn(el, series): + if el not in series: + raise AssertionError(f"{el!r} should be in {series!r}") + + +def assertNotIn(el, series): + if el in series: + raise AssertionError(f"{el!r} should not be in {series!r}") + + +def assertSequenceEqual(x, y): + if len(x) != len(y): + raise AssertionError(f"{x!r} should equal {y!r}") + if any(ix != iy for ix, iy in zip(x, y)): + raise AssertionError(f"{x!r} should equal {y!r}") + + +@contextlib.contextmanager +def assertRaises(ex_type): + try: + yield + assert False, f"expected {ex_type}" + except BaseException as ex: + if isinstance(ex, AssertionError): + raise + assert type(ex) is ex_type, f"{ex} should be {ex_type}" + + +def test_basic(): + with TestHook() as hook: + sys.audit("test_event", 1, 2, 3) + assertEqual(hook.seen[0][0], "test_event") + assertEqual(hook.seen[0][1], (1, 2, 3)) + + +def test_block_add_hook(): + # Raising an exception should prevent a new hook from being added, + # but will not propagate out. + with TestHook(raise_on_events="sys.addaudithook") as hook1: + with TestHook() as hook2: + sys.audit("test_event") + assertIn("test_event", hook1.seen_events) + assertNotIn("test_event", hook2.seen_events) + + +def test_block_add_hook_baseexception(): + # Raising BaseException will propagate out when adding a hook + with assertRaises(BaseException): + with TestHook( + raise_on_events="sys.addaudithook", exc_type=BaseException + ) as hook1: + # Adding this next hook should raise BaseException + with TestHook() as hook2: + pass + + +def test_finalize_hooks(): + sys.addaudithook(TestFinalizeHook()) + + +def test_pickle(): + import pickle + + class PicklePrint: + def __reduce_ex__(self, p): + return str, ("Pwned!",) + + payload_1 = pickle.dumps(PicklePrint()) + payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3)) + + # Before we add the hook, ensure our malicious pickle loads + assertEqual("Pwned!", pickle.loads(payload_1)) + + with TestHook(raise_on_events="pickle.find_class") as hook: + with assertRaises(RuntimeError): + # With the hook enabled, loading globals is not allowed + pickle.loads(payload_1) + # pickles with no globals are okay + pickle.loads(payload_2) + + +def test_monkeypatch(): + class A: + pass + + class B: + pass + + class C(A): + pass + + a = A() + + with TestHook() as hook: + # Catch name changes + C.__name__ = "X" + # Catch type changes + C.__bases__ = (B,) + # Ensure bypassing __setattr__ is still caught + type.__dict__["__bases__"].__set__(C, (B,)) + # Catch attribute replacement + C.__init__ = B.__init__ + # Catch attribute addition + C.new_attr = 123 + # Catch class changes + a.__class__ = B + + actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"] + assertSequenceEqual( + [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual + ) + + +def test_open(): + # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open() + try: + import ssl + + load_dh_params = ssl.create_default_context().load_dh_params + except ImportError: + load_dh_params = None + + # Try a range of "open" functions. + # All of them should fail + with TestHook(raise_on_events={"open"}) as hook: + for fn, *args in [ + (open, sys.argv[2], "r"), + (open, sys.executable, "rb"), + (open, 3, "wb"), + (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1), + (load_dh_params, sys.argv[2]), + ]: + if not fn: + continue + with assertRaises(RuntimeError): + fn(*args) + + actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]] + actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]] + assertSequenceEqual( + [ + i + for i in [ + (sys.argv[2], "r"), + (sys.executable, "r"), + (3, "w"), + (sys.argv[2], "w"), + (sys.argv[2], "rb") if load_dh_params else None, + ] + if i is not None + ], + actual_mode, + ) + assertSequenceEqual([], actual_flag) + + +def test_cantrace(): + traced = [] + + def trace(frame, event, *args): + if frame.f_code == TestHook.__call__.__code__: + traced.append(event) + + old = sys.settrace(trace) + try: + with TestHook() as hook: + # No traced call + eval("1") + + # No traced call + hook.__cantrace__ = False + eval("2") + + # One traced call + hook.__cantrace__ = True + eval("3") + + # Two traced calls (writing to private member, eval) + hook.__cantrace__ = 1 + eval("4") + + # One traced call (writing to private member) + hook.__cantrace__ = 0 + finally: + sys.settrace(old) + + assertSequenceEqual(["call"] * 4, traced) + + +if __name__ == "__main__": + from test.libregrtest.setup import suppress_msvcrt_asserts + suppress_msvcrt_asserts(False) + + test = sys.argv[1] + globals()[test]() |