diff options
author | Eric Snow <ericsnowcurrently@gmail.com> | 2018-05-16 19:04:57 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-16 19:04:57 (GMT) |
commit | 6d2cd9036c0ab78a83de43d1511befb7a7fc0ade (patch) | |
tree | 44b1a52ab32b914553b17196f4b6b75fdbf03d06 | |
parent | 55e53c309359327e54eb74b101c5a3240ea9cd45 (diff) | |
download | cpython-6d2cd9036c0ab78a83de43d1511befb7a7fc0ade.zip cpython-6d2cd9036c0ab78a83de43d1511befb7a7fc0ade.tar.gz cpython-6d2cd9036c0ab78a83de43d1511befb7a7fc0ade.tar.bz2 |
bpo-32604: Improve subinterpreter tests. (#6914)
Add more tests for subinterpreters. This patch also fixes a few small defects in the channel implementation.
-rw-r--r-- | Lib/test/test__xxsubinterpreters.py | 1194 | ||||
-rw-r--r-- | Modules/_xxsubinterpretersmodule.c | 250 | ||||
-rw-r--r-- | Python/pystate.c | 80 |
3 files changed, 1249 insertions, 275 deletions
diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 4ef7771..118f2e4 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -1,6 +1,9 @@ +from collections import namedtuple import contextlib +import itertools import os import pickle +import sys from textwrap import dedent, indent import threading import time @@ -12,23 +15,32 @@ from test.support import script_helper interpreters = support.import_module('_xxsubinterpreters') +################################## +# helpers + +def powerset(*sets): + return itertools.chain.from_iterable( + combinations(sets, r) + for r in range(len(sets)+1)) + + def _captured_script(script): r, w = os.pipe() indented = script.replace('\n', '\n ') wrapped = dedent(f""" import contextlib - with open({w}, 'w') as chan: - with contextlib.redirect_stdout(chan): + with open({w}, 'w') as spipe: + with contextlib.redirect_stdout(spipe): {indented} """) return wrapped, open(r) def _run_output(interp, request, shared=None): - script, chan = _captured_script(request) - with chan: + script, rpipe = _captured_script(request) + with rpipe: interpreters.run_string(interp, script, shared) - return chan.read() + return rpipe.read() @contextlib.contextmanager @@ -37,8 +49,8 @@ def _running(interp): def run(): interpreters.run_string(interp, dedent(f""" # wait for "signal" - with open({r}) as chan: - chan.read() + with open({r}) as rpipe: + rpipe.read() """)) t = threading.Thread(target=run) @@ -46,11 +58,248 @@ def _running(interp): yield - with open(w, 'w') as chan: - chan.write('done') + with open(w, 'w') as spipe: + spipe.write('done') t.join() +#@contextmanager +#def run_threaded(id, source, **shared): +# def run(): +# run_interp(id, source, **shared) +# t = threading.Thread(target=run) +# t.start() +# yield +# t.join() + + +def run_interp(id, source, **shared): + _run_interp(id, source, shared) + + +def _run_interp(id, source, shared, _mainns={}): + source = dedent(source) + main = interpreters.get_main() + if main == id: + if interpreters.get_current() != main: + raise RuntimeError + # XXX Run a func? + exec(source, _mainns) + else: + interpreters.run_string(id, source, shared) + + +def run_interp_threaded(id, source, **shared): + def run(): + _run(id, source, shared) + t = threading.Thread(target=run) + t.start() + t.join() + + +class Interpreter(namedtuple('Interpreter', 'name id')): + + @classmethod + def from_raw(cls, raw): + if isinstance(raw, cls): + return raw + elif isinstance(raw, str): + return cls(raw) + else: + raise NotImplementedError + + def __new__(cls, name=None, id=None): + main = interpreters.get_main() + if id == main: + if not name: + name = 'main' + elif name != 'main': + raise ValueError( + 'name mismatch (expected "main", got "{}")'.format(name)) + id = main + elif id is not None: + if not name: + name = 'interp' + elif name == 'main': + raise ValueError('name mismatch (unexpected "main")') + if not isinstance(id, interpreters.InterpreterID): + id = interpreters.InterpreterID(id) + elif not name or name == 'main': + name = 'main' + id = main + else: + id = interpreters.create() + self = super().__new__(cls, name, id) + return self + + +# XXX expect_channel_closed() is unnecessary once we improve exc propagation. + +@contextlib.contextmanager +def expect_channel_closed(): + try: + yield + except interpreters.ChannelClosedError: + pass + else: + assert False, 'channel not closed' + + +class ChannelAction(namedtuple('ChannelAction', 'action end interp')): + + def __new__(cls, action, end=None, interp=None): + if not end: + end = 'both' + if not interp: + interp = 'main' + self = super().__new__(cls, action, end, interp) + return self + + def __init__(self, *args, **kwargs): + if self.action == 'use': + if self.end not in ('same', 'opposite', 'send', 'recv'): + raise ValueError(self.end) + elif self.action in ('close', 'force-close'): + if self.end not in ('both', 'same', 'opposite', 'send', 'recv'): + raise ValueError(self.end) + else: + raise ValueError(self.action) + if self.interp not in ('main', 'same', 'other', 'extra'): + raise ValueError(self.interp) + + def resolve_end(self, end): + if self.end == 'same': + return end + elif self.end == 'opposite': + return 'recv' if end == 'send' else 'send' + else: + return self.end + + def resolve_interp(self, interp, other, extra): + if self.interp == 'same': + return interp + elif self.interp == 'other': + if other is None: + raise RuntimeError + return other + elif self.interp == 'extra': + if extra is None: + raise RuntimeError + return extra + elif self.interp == 'main': + if interp.name == 'main': + return interp + elif other and other.name == 'main': + return other + else: + raise RuntimeError + # Per __init__(), there aren't any others. + + +class ChannelState(namedtuple('ChannelState', 'pending closed')): + + def __new__(cls, pending=0, *, closed=False): + self = super().__new__(cls, pending, closed) + return self + + def incr(self): + return type(self)(self.pending + 1, closed=self.closed) + + def decr(self): + return type(self)(self.pending - 1, closed=self.closed) + + def close(self, *, force=True): + if self.closed: + if not force or self.pending == 0: + return self + return type(self)(0 if force else self.pending, closed=True) + + +def run_action(cid, action, end, state, *, hideclosed=True): + if state.closed: + if action == 'use' and end == 'recv' and state.pending: + expectfail = False + else: + expectfail = True + else: + expectfail = False + + try: + result = _run_action(cid, action, end, state) + except interpreters.ChannelClosedError: + if not hideclosed and not expectfail: + raise + result = state.close() + else: + if expectfail: + raise ... # XXX + return result + + +def _run_action(cid, action, end, state): + if action == 'use': + if end == 'send': + interpreters.channel_send(cid, b'spam') + return state.incr() + elif end == 'recv': + if not state.pending: + try: + interpreters.channel_recv(cid) + except interpreters.ChannelEmptyError: + return state + else: + raise Exception('expected ChannelEmptyError') + else: + interpreters.channel_recv(cid) + return state.decr() + else: + raise ValueError(end) + elif action == 'close': + kwargs = {} + if end in ('recv', 'send'): + kwargs[end] = True + interpreters.channel_close(cid, **kwargs) + return state.close() + elif action == 'force-close': + kwargs = { + 'force': True, + } + if end in ('recv', 'send'): + kwargs[end] = True + interpreters.channel_close(cid, **kwargs) + return state.close(force=True) + else: + raise ValueError(action) + + +def clean_up_interpreters(): + for id in interpreters.list_all(): + if id == 0: # main + continue + try: + interpreters.destroy(id) + except RuntimeError: + pass # already destroyed + + +def clean_up_channels(): + for cid in interpreters.channel_list_all(): + try: + interpreters.channel_destroy(cid) + except interpreters.ChannelNotFoundError: + pass # already destroyed + + +class TestBase(unittest.TestCase): + + def tearDown(self): + clean_up_interpreters() + clean_up_channels() + + +################################## +# misc. tests + class IsShareableTests(unittest.TestCase): def test_default_shareables(self): @@ -59,6 +308,9 @@ class IsShareableTests(unittest.TestCase): None, # builtin objects b'spam', + 'spam', + 10, + -10, ] for obj in shareables: with self.subTest(obj): @@ -86,37 +338,65 @@ class IsShareableTests(unittest.TestCase): object, object(), Exception(), - 42, 100.0, - 'spam', # user-defined types and objects Cheese, Cheese('Wensleydale'), SubBytes(b'spam'), ] for obj in not_shareables: - with self.subTest(obj): + with self.subTest(repr(obj)): self.assertFalse( interpreters.is_shareable(obj)) -class TestBase(unittest.TestCase): +class ShareableTypeTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.cid = interpreters.channel_create() def tearDown(self): - for id in interpreters.list_all(): - if id == 0: # main - continue - try: - interpreters.destroy(id) - except RuntimeError: - pass # already destroyed + interpreters.channel_destroy(self.cid) + super().tearDown() - for cid in interpreters.channel_list_all(): - try: - interpreters.channel_destroy(cid) - except interpreters.ChannelNotFoundError: - pass # already destroyed + def _assert_values(self, values): + for obj in values: + with self.subTest(obj): + interpreters.channel_send(self.cid, obj) + got = interpreters.channel_recv(self.cid) + + self.assertEqual(got, obj) + self.assertIs(type(got), type(obj)) + # XXX Check the following in the channel tests? + #self.assertIsNot(got, obj) + + def test_singletons(self): + for obj in [None]: + with self.subTest(obj): + interpreters.channel_send(self.cid, obj) + got = interpreters.channel_recv(self.cid) + + # XXX What about between interpreters? + self.assertIs(got, obj) + + def test_types(self): + self._assert_values([ + b'spam', + 9999, + self.cid, + ]) + + def test_bytes(self): + self._assert_values(i.to_bytes(2, 'little', signed=True) + for i in range(-1, 258)) + def test_int(self): + self._assert_values(range(-1, 258)) + + +################################## +# interpreter tests class ListAllTests(TestBase): @@ -147,13 +427,16 @@ class GetCurrentTests(TestBase): main = interpreters.get_main() cur = interpreters.get_current() self.assertEqual(cur, main) + self.assertIsInstance(cur, interpreters.InterpreterID) def test_subinterpreter(self): main = interpreters.get_main() interp = interpreters.create() out = _run_output(interp, dedent(""" import _xxsubinterpreters as _interpreters - print(int(_interpreters.get_current())) + cur = _interpreters.get_current() + print(cur) + assert isinstance(cur, _interpreters.InterpreterID) """)) cur = int(out.strip()) _, expected = interpreters.list_all() @@ -167,13 +450,16 @@ class GetMainTests(TestBase): [expected] = interpreters.list_all() main = interpreters.get_main() self.assertEqual(main, expected) + self.assertIsInstance(main, interpreters.InterpreterID) def test_from_subinterpreter(self): [expected] = interpreters.list_all() interp = interpreters.create() out = _run_output(interp, dedent(""" import _xxsubinterpreters as _interpreters - print(int(_interpreters.get_main())) + main = _interpreters.get_main() + print(main) + assert isinstance(main, _interpreters.InterpreterID) """)) main = int(out.strip()) self.assertEqual(main, expected) @@ -197,7 +483,7 @@ class IsRunningTests(TestBase): interp = interpreters.create() out = _run_output(interp, dedent(f""" import _xxsubinterpreters as _interpreters - if _interpreters.is_running({int(interp)}): + if _interpreters.is_running({interp}): print(True) else: print(False) @@ -257,6 +543,10 @@ class InterpreterIDTests(TestBase): with self.assertRaises(RuntimeError): interpreters.InterpreterID(int(id) + 1) # unforced + def test_str(self): + id = interpreters.InterpreterID(10, force=True) + self.assertEqual(str(id), '10') + def test_repr(self): id = interpreters.InterpreterID(10, force=True) self.assertEqual(repr(id), 'InterpreterID(10)') @@ -280,6 +570,7 @@ class CreateTests(TestBase): def test_in_main(self): id = interpreters.create() + self.assertIsInstance(id, interpreters.InterpreterID) self.assertIn(id, interpreters.list_all()) @@ -314,7 +605,8 @@ class CreateTests(TestBase): out = _run_output(id1, dedent(""" import _xxsubinterpreters as _interpreters id = _interpreters.create() - print(int(id)) + print(id) + assert isinstance(id, _interpreters.InterpreterID) """)) id2 = int(out.strip()) @@ -329,7 +621,7 @@ class CreateTests(TestBase): out = _run_output(id1, dedent(""" import _xxsubinterpreters as _interpreters id = _interpreters.create() - print(int(id)) + print(id) """)) id2 = int(out.strip()) @@ -423,7 +715,7 @@ class DestroyTests(TestBase): script = dedent(f""" import _xxsubinterpreters as _interpreters try: - _interpreters.destroy({int(id)}) + _interpreters.destroy({id}) except RuntimeError: pass """) @@ -437,7 +729,7 @@ class DestroyTests(TestBase): id2 = interpreters.create() script = dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.destroy({int(id2)}) + _interpreters.destroy({id2}) """) interpreters.run_string(id1, script) @@ -783,6 +1075,9 @@ class RunStringTests(TestBase): self.assertEqual(retcode, 0) +################################## +# channel tests + class ChannelIDTests(TestBase): def test_default_kwargs(self): @@ -842,6 +1137,10 @@ class ChannelIDTests(TestBase): with self.assertRaises(interpreters.ChannelNotFoundError): interpreters._channel_id(int(cid) + 1) # unforced + def test_str(self): + cid = interpreters._channel_id(10, force=True) + self.assertEqual(str(cid), '10') + def test_repr(self): cid = interpreters._channel_id(10, force=True) self.assertEqual(repr(cid), 'ChannelID(10)') @@ -872,6 +1171,10 @@ class ChannelIDTests(TestBase): class ChannelTests(TestBase): + def test_create_cid(self): + cid = interpreters.channel_create() + self.assertIsInstance(cid, interpreters.ChannelID) + def test_sequential_ids(self): before = interpreters.channel_list_all() id1 = interpreters.channel_create() @@ -888,7 +1191,7 @@ class ChannelTests(TestBase): out = _run_output(id1, dedent(""" import _xxsubinterpreters as _interpreters cid = _interpreters.channel_create() - print(int(cid)) + print(cid) """)) cid1 = int(out.strip()) @@ -896,7 +1199,7 @@ class ChannelTests(TestBase): out = _run_output(id2, dedent(""" import _xxsubinterpreters as _interpreters cid = _interpreters.channel_create() - print(int(cid)) + print(cid) """)) cid2 = int(out.strip()) @@ -904,127 +1207,133 @@ class ChannelTests(TestBase): #################### - def test_drop_single_user(self): + def test_send_recv_main(self): cid = interpreters.channel_create() - interpreters.channel_send(cid, b'spam') - interpreters.channel_recv(cid) - interpreters.channel_drop_interpreter(cid, send=True, recv=True) + orig = b'spam' + interpreters.channel_send(cid, orig) + obj = interpreters.channel_recv(cid) - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_send(cid, b'eggs') - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_recv(cid) + self.assertEqual(obj, orig) + self.assertIsNot(obj, orig) - def test_drop_multiple_users(self): - cid = interpreters.channel_create() + def test_send_recv_same_interpreter(self): id1 = interpreters.create() - id2 = interpreters.create() - interpreters.run_string(id1, dedent(f""" + out = _run_output(id1, dedent(""" import _xxsubinterpreters as _interpreters - _interpreters.channel_send({int(cid)}, b'spam') + cid = _interpreters.channel_create() + orig = b'spam' + _interpreters.channel_send(cid, orig) + obj = _interpreters.channel_recv(cid) + assert obj is not orig + assert obj == orig """)) - out = _run_output(id2, dedent(f""" + + def test_send_recv_different_interpreters(self): + cid = interpreters.channel_create() + id1 = interpreters.create() + out = _run_output(id1, dedent(f""" import _xxsubinterpreters as _interpreters - obj = _interpreters.channel_recv({int(cid)}) - _interpreters.channel_drop_interpreter({int(cid)}) - print(repr(obj)) - """)) - interpreters.run_string(id1, dedent(f""" - _interpreters.channel_drop_interpreter({int(cid)}) + _interpreters.channel_send({cid}, b'spam') """)) + obj = interpreters.channel_recv(cid) - self.assertEqual(out.strip(), "b'spam'") + self.assertEqual(obj, b'spam') - def test_drop_no_kwargs(self): + def test_send_recv_different_threads(self): cid = interpreters.channel_create() - interpreters.channel_send(cid, b'spam') - interpreters.channel_recv(cid) - interpreters.channel_drop_interpreter(cid) - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_send(cid, b'eggs') - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_recv(cid) + def f(): + while True: + try: + obj = interpreters.channel_recv(cid) + break + except interpreters.ChannelEmptyError: + time.sleep(0.1) + interpreters.channel_send(cid, obj) + t = threading.Thread(target=f) + t.start() - def test_drop_multiple_times(self): - cid = interpreters.channel_create() interpreters.channel_send(cid, b'spam') - interpreters.channel_recv(cid) - interpreters.channel_drop_interpreter(cid, send=True, recv=True) + t.join() + obj = interpreters.channel_recv(cid) - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_drop_interpreter(cid, send=True, recv=True) + self.assertEqual(obj, b'spam') - def test_drop_with_unused_items(self): + def test_send_recv_different_interpreters_and_threads(self): cid = interpreters.channel_create() + id1 = interpreters.create() + out = None + + def f(): + nonlocal out + out = _run_output(id1, dedent(f""" + import time + import _xxsubinterpreters as _interpreters + while True: + try: + obj = _interpreters.channel_recv({cid}) + break + except _interpreters.ChannelEmptyError: + time.sleep(0.1) + assert(obj == b'spam') + _interpreters.channel_send({cid}, b'eggs') + """)) + t = threading.Thread(target=f) + t.start() + interpreters.channel_send(cid, b'spam') - interpreters.channel_send(cid, b'ham') - interpreters.channel_drop_interpreter(cid, send=True, recv=True) + t.join() + obj = interpreters.channel_recv(cid) - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_recv(cid) + self.assertEqual(obj, b'eggs') - def test_drop_never_used(self): - cid = interpreters.channel_create() - interpreters.channel_drop_interpreter(cid) + def test_send_not_found(self): + with self.assertRaises(interpreters.ChannelNotFoundError): + interpreters.channel_send(10, b'spam') - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_send(cid, b'spam') - with self.assertRaises(interpreters.ChannelClosedError): + def test_recv_not_found(self): + with self.assertRaises(interpreters.ChannelNotFoundError): + interpreters.channel_recv(10) + + def test_recv_empty(self): + cid = interpreters.channel_create() + with self.assertRaises(interpreters.ChannelEmptyError): interpreters.channel_recv(cid) - def test_drop_by_unassociated_interp(self): + def test_run_string_arg_unresolved(self): cid = interpreters.channel_create() - interpreters.channel_send(cid, b'spam') interp = interpreters.create() - interpreters.run_string(interp, dedent(f""" + + out = _run_output(interp, dedent(""" import _xxsubinterpreters as _interpreters - _interpreters.channel_drop_interpreter({int(cid)}) - """)) + print(cid.end) + _interpreters.channel_send(cid, b'spam') + """), + dict(cid=cid.send)) obj = interpreters.channel_recv(cid) - interpreters.channel_drop_interpreter(cid) - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_send(cid, b'eggs') self.assertEqual(obj, b'spam') + self.assertEqual(out.strip(), 'send') - def test_drop_close_if_unassociated(self): + def test_run_string_arg_resolved(self): cid = interpreters.channel_create() + cid = interpreters._channel_id(cid, _resolve=True) interp = interpreters.create() - interpreters.run_string(interp, dedent(f""" - import _xxsubinterpreters as _interpreters - obj = _interpreters.channel_send({int(cid)}, b'spam') - _interpreters.channel_drop_interpreter({int(cid)}) - """)) - - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_recv(cid) - def test_drop_partially(self): - # XXX Is partial close too weird/confusing? - cid = interpreters.channel_create() - interpreters.channel_send(cid, None) - interpreters.channel_recv(cid) - interpreters.channel_send(cid, b'spam') - interpreters.channel_drop_interpreter(cid, send=True) + out = _run_output(interp, dedent(""" + import _xxsubinterpreters as _interpreters + print(chan.end) + _interpreters.channel_send(chan, b'spam') + #print(chan.id.end) + #_interpreters.channel_send(chan.id, b'spam') + """), + dict(chan=cid.send)) obj = interpreters.channel_recv(cid) self.assertEqual(obj, b'spam') + self.assertEqual(out.strip(), 'send') - def test_drop_used_multiple_times_by_single_user(self): - cid = interpreters.channel_create() - interpreters.channel_send(cid, b'spam') - interpreters.channel_send(cid, b'spam') - interpreters.channel_send(cid, b'spam') - interpreters.channel_recv(cid) - interpreters.channel_drop_interpreter(cid, send=True, recv=True) - - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_send(cid, b'eggs') - with self.assertRaises(interpreters.ChannelClosedError): - interpreters.channel_recv(cid) - - #################### + # close def test_close_single_user(self): cid = interpreters.channel_create() @@ -1043,21 +1352,21 @@ class ChannelTests(TestBase): id2 = interpreters.create() interpreters.run_string(id1, dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.channel_send({int(cid)}, b'spam') + _interpreters.channel_send({cid}, b'spam') """)) interpreters.run_string(id2, dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.channel_recv({int(cid)}) + _interpreters.channel_recv({cid}) """)) interpreters.channel_close(cid) with self.assertRaises(interpreters.RunFailedError) as cm: interpreters.run_string(id1, dedent(f""" - _interpreters.channel_send({int(cid)}, b'spam') + _interpreters.channel_send({cid}, b'spam') """)) self.assertIn('ChannelClosedError', str(cm.exception)) with self.assertRaises(interpreters.RunFailedError) as cm: interpreters.run_string(id2, dedent(f""" - _interpreters.channel_send({int(cid)}, b'spam') + _interpreters.channel_send({cid}, b'spam') """)) self.assertIn('ChannelClosedError', str(cm.exception)) @@ -1094,7 +1403,7 @@ class ChannelTests(TestBase): interp = interpreters.create() interpreters.run_string(interp, dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.channel_close({int(cid)}) + _interpreters.channel_close({cid}) """)) with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) @@ -1114,115 +1423,602 @@ class ChannelTests(TestBase): with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) - #################### - def test_send_recv_main(self): +class ChannelReleaseTests(TestBase): + + # XXX Add more test coverage a la the tests for close(). + + """ + - main / interp / other + - run in: current thread / new thread / other thread / different threads + - end / opposite + - force / no force + - used / not used (associated / not associated) + - empty / emptied / never emptied / partly emptied + - closed / not closed + - released / not released + - creator (interp) / other + - associated interpreter not running + - associated interpreter destroyed + """ + + """ + use + pre-release + release + after + check + """ + + """ + release in: main, interp1 + creator: same, other (incl. interp2) + + use: None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all + pre-release: None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all + pre-release forced: None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all + + release: same + release forced: same + + use after: None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all + release after: None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all + check released: send/recv for same/other(incl. interp2) + check closed: send/recv for same/other(incl. interp2) + """ + + def test_single_user(self): cid = interpreters.channel_create() - orig = b'spam' - interpreters.channel_send(cid, orig) - obj = interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'spam') + interpreters.channel_recv(cid) + interpreters.channel_release(cid, send=True, recv=True) - self.assertEqual(obj, orig) - self.assertIsNot(obj, orig) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) - def test_send_recv_same_interpreter(self): + def test_multiple_users(self): + cid = interpreters.channel_create() id1 = interpreters.create() - out = _run_output(id1, dedent(""" + id2 = interpreters.create() + interpreters.run_string(id1, dedent(f""" import _xxsubinterpreters as _interpreters - cid = _interpreters.channel_create() - orig = b'spam' - _interpreters.channel_send(cid, orig) - obj = _interpreters.channel_recv(cid) - assert obj is not orig - assert obj == orig + _interpreters.channel_send({cid}, b'spam') + """)) + out = _run_output(id2, dedent(f""" + import _xxsubinterpreters as _interpreters + obj = _interpreters.channel_recv({cid}) + _interpreters.channel_release({cid}) + print(repr(obj)) + """)) + interpreters.run_string(id1, dedent(f""" + _interpreters.channel_release({cid}) """)) - def test_send_recv_different_interpreters(self): + self.assertEqual(out.strip(), "b'spam'") + + def test_no_kwargs(self): cid = interpreters.channel_create() - id1 = interpreters.create() - out = _run_output(id1, dedent(f""" + interpreters.channel_send(cid, b'spam') + interpreters.channel_recv(cid) + interpreters.channel_release(cid) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_multiple_times(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_recv(cid) + interpreters.channel_release(cid, send=True, recv=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_release(cid, send=True, recv=True) + + def test_with_unused_items(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_release(cid, send=True, recv=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_never_used(self): + cid = interpreters.channel_create() + interpreters.channel_release(cid) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'spam') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_by_unassociated_interp(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interp = interpreters.create() + interpreters.run_string(interp, dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.channel_send({int(cid)}, b'spam') + _interpreters.channel_release({cid}) """)) obj = interpreters.channel_recv(cid) + interpreters.channel_release(cid) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') self.assertEqual(obj, b'spam') - def test_send_recv_different_threads(self): + def test_close_if_unassociated(self): + # XXX Something's not right with this test... cid = interpreters.channel_create() + interp = interpreters.create() + interpreters.run_string(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + obj = _interpreters.channel_send({cid}, b'spam') + _interpreters.channel_release({cid}) + """)) - def f(): - while True: - try: - obj = interpreters.channel_recv(cid) - break - except interpreters.ChannelEmptyError: - time.sleep(0.1) - interpreters.channel_send(cid, obj) - t = threading.Thread(target=f) - t.start() + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + def test_partially(self): + # XXX Is partial close too weird/confusing? + cid = interpreters.channel_create() + interpreters.channel_send(cid, None) + interpreters.channel_recv(cid) interpreters.channel_send(cid, b'spam') - t.join() + interpreters.channel_release(cid, send=True) obj = interpreters.channel_recv(cid) self.assertEqual(obj, b'spam') - def test_send_recv_different_interpreters_and_threads(self): + def test_used_multiple_times_by_single_user(self): cid = interpreters.channel_create() - id1 = interpreters.create() - out = None + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'spam') + interpreters.channel_recv(cid) + interpreters.channel_release(cid, send=True, recv=True) - def f(): - nonlocal out - out = _run_output(id1, dedent(f""" - import time - import _xxsubinterpreters as _interpreters - while True: - try: - obj = _interpreters.channel_recv({int(cid)}) - break - except _interpreters.ChannelEmptyError: - time.sleep(0.1) - assert(obj == b'spam') - _interpreters.channel_send({int(cid)}, b'eggs') - """)) - t = threading.Thread(target=f) - t.start() + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) - interpreters.channel_send(cid, b'spam') - t.join() - obj = interpreters.channel_recv(cid) - self.assertEqual(obj, b'eggs') +class ChannelCloseFixture(namedtuple('ChannelCloseFixture', + 'end interp other extra creator')): - def test_send_not_found(self): - with self.assertRaises(interpreters.ChannelNotFoundError): - interpreters.channel_send(10, b'spam') + # Set this to True to avoid creating interpreters, e.g. when + # scanning through test permutations without running them. + QUICK = False - def test_recv_not_found(self): - with self.assertRaises(interpreters.ChannelNotFoundError): - interpreters.channel_recv(10) + def __new__(cls, end, interp, other, extra, creator): + assert end in ('send', 'recv') + if cls.QUICK: + known = {} + else: + interp = Interpreter.from_raw(interp) + other = Interpreter.from_raw(other) + extra = Interpreter.from_raw(extra) + known = { + interp.name: interp, + other.name: other, + extra.name: extra, + } + if not creator: + creator = 'same' + self = super().__new__(cls, end, interp, other, extra, creator) + self._prepped = set() + self._state = ChannelState() + self._known = known + return self - def test_recv_empty(self): - cid = interpreters.channel_create() - with self.assertRaises(interpreters.ChannelEmptyError): - interpreters.channel_recv(cid) + @property + def state(self): + return self._state - def test_run_string_arg(self): - cid = interpreters.channel_create() - interp = interpreters.create() + @property + def cid(self): + try: + return self._cid + except AttributeError: + creator = self._get_interpreter(self.creator) + self._cid = self._new_channel(creator) + return self._cid + + def get_interpreter(self, interp): + interp = self._get_interpreter(interp) + self._prep_interpreter(interp) + return interp + + def expect_closed_error(self, end=None): + if end is None: + end = self.end + if end == 'recv' and self.state.closed == 'send': + return False + return bool(self.state.closed) + + def prep_interpreter(self, interp): + self._prep_interpreter(interp) + + def record_action(self, action, result): + self._state = result + + def clean_up(self): + clean_up_interpreters() + clean_up_channels() + + # internal methods + + def _new_channel(self, creator): + if creator.name == 'main': + return interpreters.channel_create() + else: + ch = interpreters.channel_create() + run_interp(creator.id, f""" + import _xxsubinterpreters + cid = _xxsubinterpreters.channel_create() + # We purposefully send back an int to avoid tying the + # channel to the other interpreter. + _xxsubinterpreters.channel_send({ch}, int(cid)) + del _xxsubinterpreters + """) + self._cid = interpreters.channel_recv(ch) + return self._cid + + def _get_interpreter(self, interp): + if interp in ('same', 'interp'): + return self.interp + elif interp == 'other': + return self.other + elif interp == 'extra': + return self.extra + else: + name = interp + try: + interp = self._known[name] + except KeyError: + interp = self._known[name] = Interpreter(name) + return interp + + def _prep_interpreter(self, interp): + if interp.id in self._prepped: + return + self._prepped.add(interp.id) + if interp.name == 'main': + return + run_interp(interp.id, f""" + import _xxsubinterpreters as interpreters + import test.test__xxsubinterpreters as helpers + ChannelState = helpers.ChannelState + try: + cid + except NameError: + cid = interpreters._channel_id({self.cid}) + """) - out = _run_output(interp, dedent(""" - import _xxsubinterpreters as _interpreters - print(cid.end) - _interpreters.channel_send(cid, b'spam') - """), - dict(cid=cid.send)) - obj = interpreters.channel_recv(cid) - self.assertEqual(obj, b'spam') - self.assertEqual(out.strip(), 'send') +@unittest.skip('these tests take several hours to run') +class ExhaustiveChannelTests(TestBase): + + """ + - main / interp / other + - run in: current thread / new thread / other thread / different threads + - end / opposite + - force / no force + - used / not used (associated / not associated) + - empty / emptied / never emptied / partly emptied + - closed / not closed + - released / not released + - creator (interp) / other + - associated interpreter not running + - associated interpreter destroyed + + - close after unbound + """ + + """ + use + pre-close + close + after + check + """ + + """ + close in: main, interp1 + creator: same, other, extra + + use: None,send,recv,send/recv in None,same,other,same+other,all + pre-close: None,send,recv in None,same,other,same+other,all + pre-close forced: None,send,recv in None,same,other,same+other,all + + close: same + close forced: same + + use after: None,send,recv,send/recv in None,same,other,extra,same+other,all + close after: None,send,recv,send/recv in None,same,other,extra,same+other,all + check closed: send/recv for same/other(incl. interp2) + """ + + def iter_action_sets(self): + # - used / not used (associated / not associated) + # - empty / emptied / never emptied / partly emptied + # - closed / not closed + # - released / not released + + # never used + yield [] + + # only pre-closed (and possible used after) + for closeactions in self._iter_close_action_sets('same', 'other'): + yield closeactions + for postactions in self._iter_post_close_action_sets(): + yield closeactions + postactions + for closeactions in self._iter_close_action_sets('other', 'extra'): + yield closeactions + for postactions in self._iter_post_close_action_sets(): + yield closeactions + postactions + + # used + for useactions in self._iter_use_action_sets('same', 'other'): + yield useactions + for closeactions in self._iter_close_action_sets('same', 'other'): + actions = useactions + closeactions + yield actions + for postactions in self._iter_post_close_action_sets(): + yield actions + postactions + for closeactions in self._iter_close_action_sets('other', 'extra'): + actions = useactions + closeactions + yield actions + for postactions in self._iter_post_close_action_sets(): + yield actions + postactions + for useactions in self._iter_use_action_sets('other', 'extra'): + yield useactions + for closeactions in self._iter_close_action_sets('same', 'other'): + actions = useactions + closeactions + yield actions + for postactions in self._iter_post_close_action_sets(): + yield actions + postactions + for closeactions in self._iter_close_action_sets('other', 'extra'): + actions = useactions + closeactions + yield actions + for postactions in self._iter_post_close_action_sets(): + yield actions + postactions + + def _iter_use_action_sets(self, interp1, interp2): + interps = (interp1, interp2) + + # only recv end used + yield [ + ChannelAction('use', 'recv', interp1), + ] + yield [ + ChannelAction('use', 'recv', interp2), + ] + yield [ + ChannelAction('use', 'recv', interp1), + ChannelAction('use', 'recv', interp2), + ] + + # never emptied + yield [ + ChannelAction('use', 'send', interp1), + ] + yield [ + ChannelAction('use', 'send', interp2), + ] + yield [ + ChannelAction('use', 'send', interp1), + ChannelAction('use', 'send', interp2), + ] + + # partially emptied + for interp1 in interps: + for interp2 in interps: + for interp3 in interps: + yield [ + ChannelAction('use', 'send', interp1), + ChannelAction('use', 'send', interp2), + ChannelAction('use', 'recv', interp3), + ] + + # fully emptied + for interp1 in interps: + for interp2 in interps: + for interp3 in interps: + for interp4 in interps: + yield [ + ChannelAction('use', 'send', interp1), + ChannelAction('use', 'send', interp2), + ChannelAction('use', 'recv', interp3), + ChannelAction('use', 'recv', interp4), + ] + + def _iter_close_action_sets(self, interp1, interp2): + ends = ('recv', 'send') + interps = (interp1, interp2) + for force in (True, False): + op = 'force-close' if force else 'close' + for interp in interps: + for end in ends: + yield [ + ChannelAction(op, end, interp), + ] + for recvop in ('close', 'force-close'): + for sendop in ('close', 'force-close'): + for recv in interps: + for send in interps: + yield [ + ChannelAction(recvop, 'recv', recv), + ChannelAction(sendop, 'send', send), + ] + + def _iter_post_close_action_sets(self): + for interp in ('same', 'extra', 'other'): + yield [ + ChannelAction('use', 'recv', interp), + ] + yield [ + ChannelAction('use', 'send', interp), + ] + + def run_actions(self, fix, actions): + for action in actions: + self.run_action(fix, action) + + def run_action(self, fix, action, *, hideclosed=True): + end = action.resolve_end(fix.end) + interp = action.resolve_interp(fix.interp, fix.other, fix.extra) + fix.prep_interpreter(interp) + if interp.name == 'main': + result = run_action( + fix.cid, + action.action, + end, + fix.state, + hideclosed=hideclosed, + ) + fix.record_action(action, result) + else: + _cid = interpreters.channel_create() + run_interp(interp.id, f""" + result = helpers.run_action( + {fix.cid}, + {repr(action.action)}, + {repr(end)}, + {repr(fix.state)}, + hideclosed={hideclosed}, + ) + interpreters.channel_send({_cid}, result.pending.to_bytes(1, 'little')) + interpreters.channel_send({_cid}, b'X' if result.closed else b'') + """) + result = ChannelState( + pending=int.from_bytes(interpreters.channel_recv(_cid), 'little'), + closed=bool(interpreters.channel_recv(_cid)), + ) + fix.record_action(action, result) + + def iter_fixtures(self): + # XXX threads? + interpreters = [ + ('main', 'interp', 'extra'), + ('interp', 'main', 'extra'), + ('interp1', 'interp2', 'extra'), + ('interp1', 'interp2', 'main'), + ] + for interp, other, extra in interpreters: + for creator in ('same', 'other', 'creator'): + for end in ('send', 'recv'): + yield ChannelCloseFixture(end, interp, other, extra, creator) + + def _close(self, fix, *, force): + op = 'force-close' if force else 'close' + close = ChannelAction(op, fix.end, 'same') + if not fix.expect_closed_error(): + self.run_action(fix, close, hideclosed=False) + else: + with self.assertRaises(interpreters.ChannelClosedError): + self.run_action(fix, close, hideclosed=False) + + def _assert_closed_in_interp(self, fix, interp=None): + if interp is None or interp.name == 'main': + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(fix.cid) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(fix.cid, b'spam') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_close(fix.cid) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_close(fix.cid, force=True) + else: + run_interp(interp.id, f""" + with helpers.expect_channel_closed(): + interpreters.channel_recv(cid) + """) + run_interp(interp.id, f""" + with helpers.expect_channel_closed(): + interpreters.channel_send(cid, b'spam') + """) + run_interp(interp.id, f""" + with helpers.expect_channel_closed(): + interpreters.channel_close(cid) + """) + run_interp(interp.id, f""" + with helpers.expect_channel_closed(): + interpreters.channel_close(cid, force=True) + """) + + def _assert_closed(self, fix): + self.assertTrue(fix.state.closed) + + for _ in range(fix.state.pending): + interpreters.channel_recv(fix.cid) + self._assert_closed_in_interp(fix) + + for interp in ('same', 'other'): + interp = fix.get_interpreter(interp) + if interp.name == 'main': + continue + self._assert_closed_in_interp(fix, interp) + + interp = fix.get_interpreter('fresh') + self._assert_closed_in_interp(fix, interp) + + def _iter_close_tests(self, verbose=False): + i = 0 + for actions in self.iter_action_sets(): + print() + for fix in self.iter_fixtures(): + i += 1 + if i > 1000: + return + if verbose: + if (i - 1) % 6 == 0: + print() + print(i, fix, '({} actions)'.format(len(actions))) + else: + if (i - 1) % 6 == 0: + print(' ', end='') + print('.', end=''); sys.stdout.flush() + yield i, fix, actions + if verbose: + print('---') + print() + + # This is useful for scanning through the possible tests. + def _skim_close_tests(self): + ChannelCloseFixture.QUICK = True + for i, fix, actions in self._iter_close_tests(): + pass + + def test_close(self): + for i, fix, actions in self._iter_close_tests(): + with self.subTest('{} {} {}'.format(i, fix, actions)): + fix.prep_interpreter(fix.interp) + self.run_actions(fix, actions) + + self._close(fix, force=False) + + self._assert_closed(fix) + # XXX Things slow down if we have too many interpreters. + fix.clean_up() + + def test_force_close(self): + for i, fix, actions in self._iter_close_tests(): + with self.subTest('{} {} {}'.format(i, fix, actions)): + fix.prep_interpreter(fix.interp) + self.run_actions(fix, actions) + + self._close(fix, force=True) + + self._assert_closed(fix) + # XXX Things slow down if we have too many interpreters. + fix.clean_up() if __name__ == '__main__': diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index f5e2ea3..5184f65 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -1250,7 +1250,9 @@ _channel_recv(_channels *channels, int64_t id) _PyCrossInterpreterData *data = _channel_next(chan, interp->id); PyThread_release_lock(mutex); if (data == NULL) { - PyErr_Format(ChannelEmptyError, "channel %d is empty", id); + if (!PyErr_Occurred()) { + PyErr_Format(ChannelEmptyError, "channel %d is empty", id); + } return NULL; } @@ -1304,12 +1306,13 @@ typedef struct channelid { PyObject_HEAD int64_t id; int end; + int resolve; _channels *channels; } channelid; static channelid * newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels, - int force) + int force, int resolve) { channelid *self = PyObject_New(channelid, cls); if (self == NULL) { @@ -1317,6 +1320,7 @@ newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels, } self->id = cid; self->end = end; + self->resolve = resolve; self->channels = channels; if (_channels_add_id_object(channels, cid) != 0) { @@ -1337,14 +1341,15 @@ static _channels * _global_channels(void); static PyObject * channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"id", "send", "recv", "force", NULL}; + static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL}; PyObject *id; int send = -1; int recv = -1; int force = 0; + int resolve = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O|$ppp:ChannelID.__init__", kwlist, - &id, &send, &recv, &force)) + "O|$pppp:ChannelID.__new__", kwlist, + &id, &send, &recv, &force, &resolve)) return NULL; // Coerce and check the ID. @@ -1376,7 +1381,8 @@ channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) end = CHANNEL_RECV; } - return (PyObject *)newchannelid(cls, cid, end, _global_channels(), force); + return (PyObject *)newchannelid(cls, cid, end, _global_channels(), + force, resolve); } static void @@ -1409,6 +1415,13 @@ channelid_repr(PyObject *self) return PyUnicode_FromFormat(fmt, name, cid->id); } +static PyObject * +channelid_str(PyObject *self) +{ + channelid *cid = (channelid *)self; + return PyUnicode_FromFormat("%d", cid->id); +} + PyObject * channelid_int(PyObject *self) { @@ -1519,14 +1532,49 @@ channelid_richcompare(PyObject *self, PyObject *other, int op) struct _channelid_xid { int64_t id; int end; + int resolve; }; static PyObject * _channelid_from_xid(_PyCrossInterpreterData *data) { struct _channelid_xid *xid = (struct _channelid_xid *)data->data; - return (PyObject *)newchannelid(&ChannelIDtype, xid->id, xid->end, - _global_channels(), 0); + // Note that we do not preserve the "resolve" flag. + PyObject *cid = (PyObject *)newchannelid(&ChannelIDtype, xid->id, xid->end, + _global_channels(), 0, 0); + if (xid->end == 0) { + return cid; + } + if (!xid->resolve) { + return cid; + } + + /* Try returning a high-level channel end but fall back to the ID. */ + PyObject *highlevel = PyImport_ImportModule("interpreters"); + if (highlevel == NULL) { + PyErr_Clear(); + highlevel = PyImport_ImportModule("test.support.interpreters"); + if (highlevel == NULL) { + goto error; + } + } + const char *clsname = (xid->end == CHANNEL_RECV) ? "RecvChannel" : + "SendChannel"; + PyObject *cls = PyObject_GetAttrString(highlevel, clsname); + Py_DECREF(highlevel); + if (cls == NULL) { + goto error; + } + PyObject *chan = PyObject_CallFunctionObjArgs(cls, cid, NULL); + if (chan == NULL) { + goto error; + } + Py_DECREF(cid); + return chan; + +error: + PyErr_Clear(); + return cid; } static int @@ -1538,6 +1586,7 @@ _channelid_shared(PyObject *obj, _PyCrossInterpreterData *data) } xid->id = ((channelid *)obj)->id; xid->end = ((channelid *)obj)->end; + xid->resolve = ((channelid *)obj)->resolve; data->data = xid; data->obj = obj; @@ -1553,7 +1602,7 @@ channelid_end(PyObject *self, void *end) channelid *cid = (channelid *)self; if (end != NULL) { return (PyObject *)newchannelid(Py_TYPE(self), cid->id, *(int *)end, - cid->channels, force); + cid->channels, force, cid->resolve); } if (cid->end == CHANNEL_SEND) { @@ -1597,7 +1646,7 @@ static PyTypeObject ChannelIDtype = { 0, /* tp_as_mapping */ channelid_hash, /* tp_hash */ 0, /* tp_call */ - 0, /* tp_str */ + (reprfunc)channelid_str, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ @@ -1878,6 +1927,13 @@ interpid_repr(PyObject *self) return PyUnicode_FromFormat("%s(%d)", name, id->id); } +static PyObject * +interpid_str(PyObject *self) +{ + interpid *id = (interpid *)self; + return PyUnicode_FromFormat("%d", id->id); +} + PyObject * interpid_int(PyObject *self) { @@ -1999,7 +2055,7 @@ static PyTypeObject InterpreterIDtype = { 0, /* tp_as_mapping */ interpid_hash, /* tp_hash */ 0, /* tp_call */ - 0, /* tp_str */ + (reprfunc)interpid_str, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ @@ -2115,10 +2171,13 @@ Create a new interpreter and return a unique generated ID."); static PyObject * -interp_destroy(PyObject *self, PyObject *args) +interp_destroy(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"id", NULL}; PyObject *id; - if (!PyArg_UnpackTuple(args, "destroy", 1, 1, &id)) { + // XXX Use "L" for id? + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:destroy", kwlist, &id)) { return NULL; } if (!PyLong_Check(id)) { @@ -2162,7 +2221,7 @@ interp_destroy(PyObject *self, PyObject *args) } PyDoc_STRVAR(destroy_doc, -"destroy(ID)\n\ +"destroy(id)\n\ \n\ Destroy the identified interpreter.\n\ \n\ @@ -2228,7 +2287,8 @@ static PyObject * interp_get_main(PyObject *self, PyObject *Py_UNUSED(ignored)) { // Currently, 0 is always the main interpreter. - return PyLong_FromLongLong(0); + PY_INT64_T id = 0; + return (PyObject *)newinterpid(&InterpreterIDtype, id, 0); } PyDoc_STRVAR(get_main_doc, @@ -2238,22 +2298,20 @@ Return the ID of main interpreter."); static PyObject * -interp_run_string(PyObject *self, PyObject *args) +interp_run_string(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"id", "script", "shared", NULL}; PyObject *id, *code; PyObject *shared = NULL; - if (!PyArg_UnpackTuple(args, "run_string", 2, 3, &id, &code, &shared)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "OU|O:run_string", kwlist, + &id, &code, &shared)) { return NULL; } if (!PyLong_Check(id)) { PyErr_SetString(PyExc_TypeError, "first arg (ID) must be an int"); return NULL; } - if (!PyUnicode_Check(code)) { - PyErr_SetString(PyExc_TypeError, - "second arg (code) must be a string"); - return NULL; - } // Look up the interpreter. PyInterpreterState *interp = _look_up(id); @@ -2281,7 +2339,7 @@ interp_run_string(PyObject *self, PyObject *args) } PyDoc_STRVAR(run_string_doc, -"run_string(ID, sourcetext)\n\ +"run_string(id, script, shared)\n\ \n\ Execute the provided string in the identified interpreter.\n\ \n\ @@ -2289,12 +2347,15 @@ See PyRun_SimpleStrings."); static PyObject * -object_is_shareable(PyObject *self, PyObject *args) +object_is_shareable(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"obj", NULL}; PyObject *obj; - if (!PyArg_UnpackTuple(args, "is_shareable", 1, 1, &obj)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:is_shareable", kwlist, &obj)) { return NULL; } + if (_PyObject_CheckCrossInterpreterData(obj) == 0) { Py_RETURN_TRUE; } @@ -2310,10 +2371,12 @@ False otherwise."); static PyObject * -interp_is_running(PyObject *self, PyObject *args) +interp_is_running(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"id", NULL}; PyObject *id; - if (!PyArg_UnpackTuple(args, "is_running", 1, 1, &id)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:is_running", kwlist, &id)) { return NULL; } if (!PyLong_Check(id)) { @@ -2348,7 +2411,7 @@ channel_create(PyObject *self, PyObject *Py_UNUSED(ignored)) return NULL; } PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, cid, 0, - &_globals.channels, 0); + &_globals.channels, 0, 0); if (id == NULL) { if (_channel_destroy(&_globals.channels, cid) != 0) { // XXX issue a warning? @@ -2360,15 +2423,17 @@ channel_create(PyObject *self, PyObject *Py_UNUSED(ignored)) } PyDoc_STRVAR(channel_create_doc, -"channel_create() -> ID\n\ +"channel_create() -> cid\n\ \n\ Create a new cross-interpreter channel and return a unique generated ID."); static PyObject * -channel_destroy(PyObject *self, PyObject *args) +channel_destroy(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"cid", NULL}; PyObject *id; - if (!PyArg_UnpackTuple(args, "channel_destroy", 1, 1, &id)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:channel_destroy", kwlist, &id)) { return NULL; } int64_t cid = _coerce_id(id); @@ -2383,7 +2448,7 @@ channel_destroy(PyObject *self, PyObject *args) } PyDoc_STRVAR(channel_destroy_doc, -"channel_destroy(ID)\n\ +"channel_destroy(cid)\n\ \n\ Close and finalize the channel. Afterward attempts to use the channel\n\ will behave as though it never existed."); @@ -2406,7 +2471,7 @@ channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) int64_t *cur = cids; for (int64_t i=0; i < count; cur++, i++) { PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cur, 0, - &_globals.channels, 0); + &_globals.channels, 0, 0); if (id == NULL) { Py_DECREF(ids); ids = NULL; @@ -2421,16 +2486,18 @@ finally: } PyDoc_STRVAR(channel_list_all_doc, -"channel_list_all() -> [ID]\n\ +"channel_list_all() -> [cid]\n\ \n\ Return the list of all IDs for active channels."); static PyObject * -channel_send(PyObject *self, PyObject *args) +channel_send(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"cid", "obj", NULL}; PyObject *id; PyObject *obj; - if (!PyArg_UnpackTuple(args, "channel_send", 2, 2, &id, &obj)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "OO:channel_send", kwlist, &id, &obj)) { return NULL; } int64_t cid = _coerce_id(id); @@ -2445,15 +2512,17 @@ channel_send(PyObject *self, PyObject *args) } PyDoc_STRVAR(channel_send_doc, -"channel_send(ID, obj)\n\ +"channel_send(cid, obj)\n\ \n\ Add the object's data to the channel's queue."); static PyObject * -channel_recv(PyObject *self, PyObject *args) +channel_recv(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"cid", NULL}; PyObject *id; - if (!PyArg_UnpackTuple(args, "channel_recv", 1, 1, &id)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O:channel_recv", kwlist, &id)) { return NULL; } int64_t cid = _coerce_id(id); @@ -2465,17 +2534,34 @@ channel_recv(PyObject *self, PyObject *args) } PyDoc_STRVAR(channel_recv_doc, -"channel_recv(ID) -> obj\n\ +"channel_recv(cid) -> obj\n\ \n\ Return a new object from the data at the from of the channel's queue."); static PyObject * -channel_close(PyObject *self, PyObject *id) +channel_close(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; + PyObject *id; + int send = 0; + int recv = 0; + int force = 0; + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "O|$ppp:channel_close", kwlist, + &id, &send, &recv, &force)) { + return NULL; + } int64_t cid = _coerce_id(id); if (cid < 0) { return NULL; } + if (send == 0 && recv == 0) { + send = 1; + recv = 1; + } + + // XXX Handle the ends. + // XXX Handle force is True. if (_channel_close(&_globals.channels, cid) != 0) { return NULL; @@ -2484,48 +2570,66 @@ channel_close(PyObject *self, PyObject *id) } PyDoc_STRVAR(channel_close_doc, -"channel_close(ID)\n\ +"channel_close(cid, *, send=None, recv=None, force=False)\n\ +\n\ +Close the channel for all interpreters.\n\ \n\ -Close the channel for all interpreters. Once the channel's ID has\n\ -no more ref counts the channel will be destroyed."); +If the channel is empty then the keyword args are ignored and both\n\ +ends are immediately closed. Otherwise, if 'force' is True then\n\ +all queued items are released and both ends are immediately\n\ +closed.\n\ +\n\ +If the channel is not empty *and* 'force' is False then following\n\ +happens:\n\ +\n\ + * recv is True (regardless of send):\n\ + - raise ChannelNotEmptyError\n\ + * recv is None and send is None:\n\ + - raise ChannelNotEmptyError\n\ + * send is True and recv is not True:\n\ + - fully close the 'send' end\n\ + - close the 'recv' end to interpreters not already receiving\n\ + - fully close it once empty\n\ +\n\ +Closing an already closed channel results in a ChannelClosedError.\n\ +\n\ +Once the channel's ID has no more ref counts in any interpreter\n\ +the channel will be destroyed."); static PyObject * -channel_drop_interpreter(PyObject *self, PyObject *args, PyObject *kwds) +channel_release(PyObject *self, PyObject *args, PyObject *kwds) { // Note that only the current interpreter is affected. - static char *kwlist[] = {"id", "send", "recv", NULL}; + static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; PyObject *id; - int send = -1; - int recv = -1; + int send = 0; + int recv = 0; + int force = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O|$pp:channel_drop_interpreter", kwlist, - &id, &send, &recv)) + "O|$ppp:channel_release", kwlist, + &id, &send, &recv, &force)) { return NULL; - + } int64_t cid = _coerce_id(id); if (cid < 0) { return NULL; } - if (send < 0 && recv < 0) { + if (send == 0 && recv == 0) { send = 1; recv = 1; } - else { - if (send < 0) { - send = 0; - } - if (recv < 0) { - recv = 0; - } - } + + // XXX Handle force is True. + // XXX Fix implicit release. + if (_channel_drop(&_globals.channels, cid, send, recv) != 0) { return NULL; } Py_RETURN_NONE; } -PyDoc_STRVAR(channel_drop_interpreter_doc, -"channel_drop_interpreter(ID, *, send=None, recv=None)\n\ +PyDoc_STRVAR(channel_release_doc, +"channel_release(cid, *, send=None, recv=None, force=True)\n\ \n\ Close the channel for the current interpreter. 'send' and 'recv'\n\ (bool) may be used to indicate the ends to close. By default both\n\ @@ -2541,7 +2645,7 @@ static PyMethodDef module_functions[] = { {"create", (PyCFunction)interp_create, METH_VARARGS, create_doc}, {"destroy", (PyCFunction)interp_destroy, - METH_VARARGS, destroy_doc}, + METH_VARARGS | METH_KEYWORDS, destroy_doc}, {"list_all", interp_list_all, METH_NOARGS, list_all_doc}, {"get_current", interp_get_current, @@ -2549,27 +2653,27 @@ static PyMethodDef module_functions[] = { {"get_main", interp_get_main, METH_NOARGS, get_main_doc}, {"is_running", (PyCFunction)interp_is_running, - METH_VARARGS, is_running_doc}, + METH_VARARGS | METH_KEYWORDS, is_running_doc}, {"run_string", (PyCFunction)interp_run_string, - METH_VARARGS, run_string_doc}, + METH_VARARGS | METH_KEYWORDS, run_string_doc}, {"is_shareable", (PyCFunction)object_is_shareable, - METH_VARARGS, is_shareable_doc}, + METH_VARARGS | METH_KEYWORDS, is_shareable_doc}, {"channel_create", channel_create, METH_NOARGS, channel_create_doc}, {"channel_destroy", (PyCFunction)channel_destroy, - METH_VARARGS, channel_destroy_doc}, + METH_VARARGS | METH_KEYWORDS, channel_destroy_doc}, {"channel_list_all", channel_list_all, METH_NOARGS, channel_list_all_doc}, {"channel_send", (PyCFunction)channel_send, - METH_VARARGS, channel_send_doc}, + METH_VARARGS | METH_KEYWORDS, channel_send_doc}, {"channel_recv", (PyCFunction)channel_recv, - METH_VARARGS, channel_recv_doc}, - {"channel_close", channel_close, - METH_O, channel_close_doc}, - {"channel_drop_interpreter", (PyCFunction)channel_drop_interpreter, - METH_VARARGS | METH_KEYWORDS, channel_drop_interpreter_doc}, + METH_VARARGS | METH_KEYWORDS, channel_recv_doc}, + {"channel_close", (PyCFunction)channel_close, + METH_VARARGS | METH_KEYWORDS, channel_close_doc}, + {"channel_release", (PyCFunction)channel_release, + METH_VARARGS | METH_KEYWORDS, channel_release_doc}, {"_channel_id", (PyCFunction)channel__channel_id, METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/Python/pystate.c b/Python/pystate.c index 140d2fb..d276bfc 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -1308,6 +1308,10 @@ _PyCrossInterpreterData_Register_Class(PyTypeObject *cls, return res; } +/* Cross-interpreter objects are looked up by exact match on the class. + We can reassess this policy when we move from a global registry to a + tp_* slot. */ + crossinterpdatafunc _PyCrossInterpreterData_Lookup(PyObject *obj) { @@ -1332,19 +1336,79 @@ _PyCrossInterpreterData_Lookup(PyObject *obj) /* cross-interpreter data for builtin types */ +struct _shared_bytes_data { + char *bytes; + Py_ssize_t len; +}; + static PyObject * _new_bytes_object(_PyCrossInterpreterData *data) { - return PyBytes_FromString((char *)(data->data)); + struct _shared_bytes_data *shared = (struct _shared_bytes_data *)(data->data); + return PyBytes_FromStringAndSize(shared->bytes, shared->len); } static int _bytes_shared(PyObject *obj, _PyCrossInterpreterData *data) { - data->data = (void *)(PyBytes_AS_STRING(obj)); + struct _shared_bytes_data *shared = PyMem_NEW(struct _shared_bytes_data, 1); + if (PyBytes_AsStringAndSize(obj, &shared->bytes, &shared->len) < 0) { + return -1; + } + data->data = (void *)shared; data->obj = obj; // Will be "released" (decref'ed) when data released. data->new_object = _new_bytes_object; - data->free = NULL; // Do not free the data (it belongs to the object). + data->free = PyMem_Free; + return 0; +} + +struct _shared_str_data { + int kind; + const void *buffer; + Py_ssize_t len; +}; + +static PyObject * +_new_str_object(_PyCrossInterpreterData *data) +{ + struct _shared_str_data *shared = (struct _shared_str_data *)(data->data); + return PyUnicode_FromKindAndData(shared->kind, shared->buffer, shared->len); +} + +static int +_str_shared(PyObject *obj, _PyCrossInterpreterData *data) +{ + struct _shared_str_data *shared = PyMem_NEW(struct _shared_str_data, 1); + shared->kind = PyUnicode_KIND(obj); + shared->buffer = PyUnicode_DATA(obj); + shared->len = PyUnicode_GET_LENGTH(obj) - 1; + data->data = (void *)shared; + data->obj = obj; // Will be "released" (decref'ed) when data released. + data->new_object = _new_str_object; + data->free = PyMem_Free; + return 0; +} + +static PyObject * +_new_long_object(_PyCrossInterpreterData *data) +{ + return PyLong_FromLongLong((int64_t)(data->data)); +} + +static int +_long_shared(PyObject *obj, _PyCrossInterpreterData *data) +{ + int64_t value = PyLong_AsLongLong(obj); + if (value == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + PyErr_SetString(PyExc_OverflowError, "try sending as bytes"); + } + return -1; + } + data->data = (void *)value; + data->obj = NULL; + data->new_object = _new_long_object; + data->free = NULL; return 0; } @@ -1374,10 +1438,20 @@ _register_builtins_for_crossinterpreter_data(void) Py_FatalError("could not register None for cross-interpreter sharing"); } + // int + if (_register_xidata(&PyLong_Type, _long_shared) != 0) { + Py_FatalError("could not register int for cross-interpreter sharing"); + } + // bytes if (_register_xidata(&PyBytes_Type, _bytes_shared) != 0) { Py_FatalError("could not register bytes for cross-interpreter sharing"); } + + // str + if (_register_xidata(&PyUnicode_Type, _str_shared) != 0) { + Py_FatalError("could not register str for cross-interpreter sharing"); + } } |