summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2013-12-06 20:57:40 (GMT)
committerGuido van Rossum <guido@python.org>2013-12-06 20:57:40 (GMT)
commit1a605ed5a33dbeec6b98d2a073fbbe3fcfdd84c0 (patch)
treefce1163198c3e6f531b8cd0e03413f589ada61c0
parent2f8c83568ca5850601e92e315c4a1c840e94b1cb (diff)
downloadcpython-1a605ed5a33dbeec6b98d2a073fbbe3fcfdd84c0.zip
cpython-1a605ed5a33dbeec6b98d2a073fbbe3fcfdd84c0.tar.gz
cpython-1a605ed5a33dbeec6b98d2a073fbbe3fcfdd84c0.tar.bz2
asyncio: Add Task.current_task() class method.
-rw-r--r--Lib/asyncio/tasks.py20
-rw-r--r--Lib/asyncio/test_utils.py2
-rw-r--r--Lib/test/test_asyncio/test_tasks.py36
3 files changed, 57 insertions, 1 deletions
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 999e962..cd9718f 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -122,6 +122,22 @@ class Task(futures.Future):
# Weak set containing all tasks alive.
_all_tasks = weakref.WeakSet()
+ # Dictionary containing tasks that are currently active in
+ # all running event loops. {EventLoop: Task}
+ _current_tasks = {}
+
+ @classmethod
+ def current_task(cls, loop=None):
+ """Return the currently running task in an event loop or None.
+
+ By default the current task for the current event loop is returned.
+
+ None is returned when called not in the context of a Task.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+ return cls._current_tasks.get(loop)
+
@classmethod
def all_tasks(cls, loop=None):
"""Return a set of all tasks for an event loop.
@@ -252,6 +268,8 @@ class Task(futures.Future):
self._must_cancel = False
coro = self._coro
self._fut_waiter = None
+
+ self.__class__._current_tasks[self._loop] = self
# Call either coro.throw(exc) or coro.send(value).
try:
if exc is not None:
@@ -302,6 +320,8 @@ class Task(futures.Future):
self._step, None,
RuntimeError(
'Task got bad yield: {!r}'.format(result)))
+ finally:
+ self.__class__._current_tasks.pop(self._loop)
self = None
def _wakeup(self, future):
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
index c26dd88..131a546 100644
--- a/Lib/asyncio/test_utils.py
+++ b/Lib/asyncio/test_utils.py
@@ -88,7 +88,7 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SSLWSGIServer(SilentWSGIServer):
def finish_request(self, request, client_address):
# The relative location of our test directory (which
- # contains the sample key and certificate files) differs
+ # contains the ssl key and certificate files) differs
# between the stdlib and stand-alone Tulip/asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 8f0d081..5470da1 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -1113,6 +1113,42 @@ class TaskTests(unittest.TestCase):
self.assertEqual(res, 'test')
self.assertIsNone(t2.result())
+ def test_current_task(self):
+ self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+ @tasks.coroutine
+ def coro(loop):
+ self.assertTrue(tasks.Task.current_task(loop=loop) is task)
+
+ task = tasks.Task(coro(self.loop), loop=self.loop)
+ self.loop.run_until_complete(task)
+ self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+
+ def test_current_task_with_interleaving_tasks(self):
+ self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def coro1(loop):
+ self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
+ yield from fut1
+ self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
+ fut2.set_result(True)
+
+ @tasks.coroutine
+ def coro2(loop):
+ self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
+ fut1.set_result(True)
+ yield from fut2
+ self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
+
+ task1 = tasks.Task(coro1(self.loop), loop=self.loop)
+ task2 = tasks.Task(coro2(self.loop), loop=self.loop)
+
+ self.loop.run_until_complete(tasks.wait((task1, task2), loop=self.loop))
+ self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+
# Some thorough tests for cancellation propagation through
# coroutines, tasks and wait().