summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-01-23 00:11:18 (GMT)
committerGitHub <noreply@github.com>2018-01-23 00:11:18 (GMT)
commitf23746a934177c48eff754411aba54c31d6be2f0 (patch)
tree4b32964b53fa87701f71c71937792f2489b7bbb4
parent9089a265918754d95e105a7c4c409ac9352c87bb (diff)
downloadcpython-f23746a934177c48eff754411aba54c31d6be2f0.zip
cpython-f23746a934177c48eff754411aba54c31d6be2f0.tar.gz
cpython-f23746a934177c48eff754411aba54c31d6be2f0.tar.bz2
bpo-32436: Implement PEP 567 (#5027)
-rw-r--r--Include/Python.h1
-rw-r--r--Include/context.h86
-rw-r--r--Include/internal/context.h41
-rw-r--r--Include/internal/hamt.h113
-rw-r--r--Include/pystate.h8
-rw-r--r--Lib/asyncio/base_events.py21
-rw-r--r--Lib/asyncio/base_futures.py8
-rw-r--r--Lib/asyncio/events.py15
-rw-r--r--Lib/asyncio/futures.py17
-rw-r--r--Lib/asyncio/selector_events.py4
-rw-r--r--Lib/asyncio/tasks.py24
-rw-r--r--Lib/asyncio/unix_events.py2
-rw-r--r--Lib/contextvars.py4
-rw-r--r--Lib/test/test_asyncio/test_base_events.py10
-rw-r--r--Lib/test/test_asyncio/test_futures.py14
-rw-r--r--Lib/test/test_asyncio/test_tasks.py109
-rw-r--r--Lib/test/test_asyncio/utils.py8
-rw-r--r--Lib/test/test_context.py1064
-rw-r--r--Makefile.pre.in4
-rw-r--r--Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst1
-rw-r--r--Modules/Setup.dist1
-rw-r--r--Modules/_asynciomodule.c209
-rw-r--r--Modules/_contextvarsmodule.c75
-rw-r--r--Modules/_testcapimodule.c8
-rw-r--r--Modules/clinic/_asynciomodule.c.h29
-rw-r--r--Modules/clinic/_contextvarsmodule.c.h21
-rw-r--r--Modules/gcmodule.c2
-rw-r--r--Objects/object.c1
-rw-r--r--PCbuild/_contextvars.vcxproj77
-rw-r--r--PCbuild/_contextvars.vcxproj.filters16
-rw-r--r--PCbuild/_decimal.vcxproj2
-rw-r--r--PCbuild/pcbuild.proj2
-rw-r--r--PCbuild/pythoncore.vcxproj6
-rw-r--r--PCbuild/pythoncore.vcxproj.filters18
-rw-r--r--Python/clinic/context.c.h146
-rw-r--r--Python/context.c1220
-rw-r--r--Python/hamt.c2982
-rw-r--r--Python/pylifecycle.c6
-rw-r--r--Python/pystate.c9
-rw-r--r--Tools/msi/lib/lib_files.wxs2
-rw-r--r--setup.py3
41 files changed, 6269 insertions, 120 deletions
diff --git a/Include/Python.h b/Include/Python.h
index 20051e7..dd595ea 100644
--- a/Include/Python.h
+++ b/Include/Python.h
@@ -109,6 +109,7 @@
#include "pyerrors.h"
#include "pystate.h"
+#include "context.h"
#include "pyarena.h"
#include "modsupport.h"
diff --git a/Include/context.h b/Include/context.h
new file mode 100644
index 0000000..f872dce
--- /dev/null
+++ b/Include/context.h
@@ -0,0 +1,86 @@
+#ifndef Py_CONTEXT_H
+#define Py_CONTEXT_H
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifndef Py_LIMITED_API
+
+
+PyAPI_DATA(PyTypeObject) PyContext_Type;
+typedef struct _pycontextobject PyContext;
+
+PyAPI_DATA(PyTypeObject) PyContextVar_Type;
+typedef struct _pycontextvarobject PyContextVar;
+
+PyAPI_DATA(PyTypeObject) PyContextToken_Type;
+typedef struct _pycontexttokenobject PyContextToken;
+
+
+#define PyContext_CheckExact(o) (Py_TYPE(o) == &PyContext_Type)
+#define PyContextVar_CheckExact(o) (Py_TYPE(o) == &PyContextVar_Type)
+#define PyContextToken_CheckExact(o) (Py_TYPE(o) == &PyContextToken_Type)
+
+
+PyAPI_FUNC(PyContext *) PyContext_New(void);
+PyAPI_FUNC(PyContext *) PyContext_Copy(PyContext *);
+PyAPI_FUNC(PyContext *) PyContext_CopyCurrent(void);
+
+PyAPI_FUNC(int) PyContext_Enter(PyContext *);
+PyAPI_FUNC(int) PyContext_Exit(PyContext *);
+
+
+/* Create a new context variable.
+
+ default_value can be NULL.
+*/
+PyAPI_FUNC(PyContextVar *) PyContextVar_New(
+ const char *name, PyObject *default_value);
+
+
+/* Get a value for the variable.
+
+ Returns -1 if an error occurred during lookup.
+
+ Returns 0 if value either was or was not found.
+
+ If value was found, *value will point to it.
+ If not, it will point to:
+
+ - default_value, if not NULL;
+ - the default value of "var", if not NULL;
+ - NULL.
+
+ '*value' will be a new ref, if not NULL.
+*/
+PyAPI_FUNC(int) PyContextVar_Get(
+ PyContextVar *var, PyObject *default_value, PyObject **value);
+
+
+/* Set a new value for the variable.
+ Returns NULL if an error occurs.
+*/
+PyAPI_FUNC(PyContextToken *) PyContextVar_Set(
+ PyContextVar *var, PyObject *value);
+
+
+/* Reset a variable to its previous value.
+ Returns 0 on sucess, -1 on error.
+*/
+PyAPI_FUNC(int) PyContextVar_Reset(
+ PyContextVar *var, PyContextToken *token);
+
+
+/* This method is exposed only for CPython tests. Don not use it. */
+PyAPI_FUNC(PyObject *) _PyContext_NewHamtForTests(void);
+
+
+PyAPI_FUNC(int) PyContext_ClearFreeList(void);
+
+
+#endif /* !Py_LIMITED_API */
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* !Py_CONTEXT_H */
diff --git a/Include/internal/context.h b/Include/internal/context.h
new file mode 100644
index 0000000..59f88f2
--- /dev/null
+++ b/Include/internal/context.h
@@ -0,0 +1,41 @@
+#ifndef Py_INTERNAL_CONTEXT_H
+#define Py_INTERNAL_CONTEXT_H
+
+
+#include "internal/hamt.h"
+
+
+struct _pycontextobject {
+ PyObject_HEAD
+ PyContext *ctx_prev;
+ PyHamtObject *ctx_vars;
+ PyObject *ctx_weakreflist;
+ int ctx_entered;
+};
+
+
+struct _pycontextvarobject {
+ PyObject_HEAD
+ PyObject *var_name;
+ PyObject *var_default;
+ PyObject *var_cached;
+ uint64_t var_cached_tsid;
+ uint64_t var_cached_tsver;
+ Py_hash_t var_hash;
+};
+
+
+struct _pycontexttokenobject {
+ PyObject_HEAD
+ PyContext *tok_ctx;
+ PyContextVar *tok_var;
+ PyObject *tok_oldval;
+ int tok_used;
+};
+
+
+int _PyContext_Init(void);
+void _PyContext_Fini(void);
+
+
+#endif /* !Py_INTERNAL_CONTEXT_H */
diff --git a/Include/internal/hamt.h b/Include/internal/hamt.h
new file mode 100644
index 0000000..52488d0
--- /dev/null
+++ b/Include/internal/hamt.h
@@ -0,0 +1,113 @@
+#ifndef Py_INTERNAL_HAMT_H
+#define Py_INTERNAL_HAMT_H
+
+
+#define _Py_HAMT_MAX_TREE_DEPTH 7
+
+
+#define PyHamt_Check(o) (Py_TYPE(o) == &_PyHamt_Type)
+
+
+/* Abstract tree node. */
+typedef struct {
+ PyObject_HEAD
+} PyHamtNode;
+
+
+/* An HAMT immutable mapping collection. */
+typedef struct {
+ PyObject_HEAD
+ PyHamtNode *h_root;
+ PyObject *h_weakreflist;
+ Py_ssize_t h_count;
+} PyHamtObject;
+
+
+/* A struct to hold the state of depth-first traverse of the tree.
+
+ HAMT is an immutable collection. Iterators will hold a strong reference
+ to it, and every node in the HAMT has strong references to its children.
+
+ So for iterators, we can implement zero allocations and zero reference
+ inc/dec depth-first iteration.
+
+ - i_nodes: an array of seven pointers to tree nodes
+ - i_level: the current node in i_nodes
+ - i_pos: an array of positions within nodes in i_nodes.
+*/
+typedef struct {
+ PyHamtNode *i_nodes[_Py_HAMT_MAX_TREE_DEPTH];
+ Py_ssize_t i_pos[_Py_HAMT_MAX_TREE_DEPTH];
+ int8_t i_level;
+} PyHamtIteratorState;
+
+
+/* Base iterator object.
+
+ Contains the iteration state, a pointer to the HAMT tree,
+ and a pointer to the 'yield function'. The latter is a simple
+ function that returns a key/value tuple for the 'Items' iterator,
+ just a key for the 'Keys' iterator, and a value for the 'Values'
+ iterator.
+*/
+typedef struct {
+ PyObject_HEAD
+ PyHamtObject *hi_obj;
+ PyHamtIteratorState hi_iter;
+ binaryfunc hi_yield;
+} PyHamtIterator;
+
+
+PyAPI_DATA(PyTypeObject) _PyHamt_Type;
+PyAPI_DATA(PyTypeObject) _PyHamt_ArrayNode_Type;
+PyAPI_DATA(PyTypeObject) _PyHamt_BitmapNode_Type;
+PyAPI_DATA(PyTypeObject) _PyHamt_CollisionNode_Type;
+PyAPI_DATA(PyTypeObject) _PyHamtKeys_Type;
+PyAPI_DATA(PyTypeObject) _PyHamtValues_Type;
+PyAPI_DATA(PyTypeObject) _PyHamtItems_Type;
+
+
+/* Create a new HAMT immutable mapping. */
+PyHamtObject * _PyHamt_New(void);
+
+/* Return a new collection based on "o", but with an additional
+ key/val pair. */
+PyHamtObject * _PyHamt_Assoc(PyHamtObject *o, PyObject *key, PyObject *val);
+
+/* Return a new collection based on "o", but without "key". */
+PyHamtObject * _PyHamt_Without(PyHamtObject *o, PyObject *key);
+
+/* Find "key" in the "o" collection.
+
+ Return:
+ - -1: An error ocurred.
+ - 0: "key" wasn't found in "o".
+ - 1: "key" is in "o"; "*val" is set to its value (a borrowed ref).
+*/
+int _PyHamt_Find(PyHamtObject *o, PyObject *key, PyObject **val);
+
+/* Check if "v" is equal to "w".
+
+ Return:
+ - 0: v != w
+ - 1: v == w
+ - -1: An error occurred.
+*/
+int _PyHamt_Eq(PyHamtObject *v, PyHamtObject *w);
+
+/* Return the size of "o"; equivalent of "len(o)". */
+Py_ssize_t _PyHamt_Len(PyHamtObject *o);
+
+/* Return a Keys iterator over "o". */
+PyObject * _PyHamt_NewIterKeys(PyHamtObject *o);
+
+/* Return a Values iterator over "o". */
+PyObject * _PyHamt_NewIterValues(PyHamtObject *o);
+
+/* Return a Items iterator over "o". */
+PyObject * _PyHamt_NewIterItems(PyHamtObject *o);
+
+int _PyHamt_Init(void);
+void _PyHamt_Fini(void);
+
+#endif /* !Py_INTERNAL_HAMT_H */
diff --git a/Include/pystate.h b/Include/pystate.h
index 5a69e14..d004be5 100644
--- a/Include/pystate.h
+++ b/Include/pystate.h
@@ -143,6 +143,8 @@ typedef struct _is {
/* AtExit module */
void (*pyexitfunc)(PyObject *);
PyObject *pyexitmodule;
+
+ uint64_t tstate_next_unique_id;
} PyInterpreterState;
#endif /* !Py_LIMITED_API */
@@ -270,6 +272,12 @@ typedef struct _ts {
PyObject *async_gen_firstiter;
PyObject *async_gen_finalizer;
+ PyObject *context;
+ uint64_t context_ver;
+
+ /* Unique thread state id. */
+ uint64_t id;
+
/* XXX signal handlers should also be here */
} PyThreadState;
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index ca9eee7..e722cf2 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -489,7 +489,7 @@ class BaseEventLoop(events.AbstractEventLoop):
"""
return time.monotonic()
- def call_later(self, delay, callback, *args):
+ def call_later(self, delay, callback, *args, context=None):
"""Arrange for a callback to be called at a given time.
Return a Handle: an opaque object with a cancel() method that
@@ -505,12 +505,13 @@ class BaseEventLoop(events.AbstractEventLoop):
Any positional arguments after the callback will be passed to
the callback when it is called.
"""
- timer = self.call_at(self.time() + delay, callback, *args)
+ timer = self.call_at(self.time() + delay, callback, *args,
+ context=context)
if timer._source_traceback:
del timer._source_traceback[-1]
return timer
- def call_at(self, when, callback, *args):
+ def call_at(self, when, callback, *args, context=None):
"""Like call_later(), but uses an absolute time.
Absolute time corresponds to the event loop's time() method.
@@ -519,14 +520,14 @@ class BaseEventLoop(events.AbstractEventLoop):
if self._debug:
self._check_thread()
self._check_callback(callback, 'call_at')
- timer = events.TimerHandle(when, callback, args, self)
+ timer = events.TimerHandle(when, callback, args, self, context)
if timer._source_traceback:
del timer._source_traceback[-1]
heapq.heappush(self._scheduled, timer)
timer._scheduled = True
return timer
- def call_soon(self, callback, *args):
+ def call_soon(self, callback, *args, context=None):
"""Arrange for a callback to be called as soon as possible.
This operates as a FIFO queue: callbacks are called in the
@@ -540,7 +541,7 @@ class BaseEventLoop(events.AbstractEventLoop):
if self._debug:
self._check_thread()
self._check_callback(callback, 'call_soon')
- handle = self._call_soon(callback, args)
+ handle = self._call_soon(callback, args, context)
if handle._source_traceback:
del handle._source_traceback[-1]
return handle
@@ -555,8 +556,8 @@ class BaseEventLoop(events.AbstractEventLoop):
f'a callable object was expected by {method}(), '
f'got {callback!r}')
- def _call_soon(self, callback, args):
- handle = events.Handle(callback, args, self)
+ def _call_soon(self, callback, args, context):
+ handle = events.Handle(callback, args, self, context)
if handle._source_traceback:
del handle._source_traceback[-1]
self._ready.append(handle)
@@ -579,12 +580,12 @@ class BaseEventLoop(events.AbstractEventLoop):
"Non-thread-safe operation invoked on an event loop other "
"than the current one")
- def call_soon_threadsafe(self, callback, *args):
+ def call_soon_threadsafe(self, callback, *args, context=None):
"""Like call_soon(), but thread-safe."""
self._check_closed()
if self._debug:
self._check_callback(callback, 'call_soon_threadsafe')
- handle = self._call_soon(callback, args)
+ handle = self._call_soon(callback, args, context)
if handle._source_traceback:
del handle._source_traceback[-1]
self._write_to_self()
diff --git a/Lib/asyncio/base_futures.py b/Lib/asyncio/base_futures.py
index 008812e..5182884 100644
--- a/Lib/asyncio/base_futures.py
+++ b/Lib/asyncio/base_futures.py
@@ -41,13 +41,13 @@ def _format_callbacks(cb):
return format_helpers._format_callback_source(callback, ())
if size == 1:
- cb = format_cb(cb[0])
+ cb = format_cb(cb[0][0])
elif size == 2:
- cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1]))
+ cb = '{}, {}'.format(format_cb(cb[0][0]), format_cb(cb[1][0]))
elif size > 2:
- cb = '{}, <{} more>, {}'.format(format_cb(cb[0]),
+ cb = '{}, <{} more>, {}'.format(format_cb(cb[0][0]),
size - 2,
- format_cb(cb[-1]))
+ format_cb(cb[-1][0]))
return f'cb=[{cb}]'
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index d5365dc..5c68d4c 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -11,6 +11,7 @@ __all__ = (
'_get_running_loop',
)
+import contextvars
import os
import socket
import subprocess
@@ -32,9 +33,13 @@ class Handle:
"""Object returned by callback registration methods."""
__slots__ = ('_callback', '_args', '_cancelled', '_loop',
- '_source_traceback', '_repr', '__weakref__')
+ '_source_traceback', '_repr', '__weakref__',
+ '_context')
- def __init__(self, callback, args, loop):
+ def __init__(self, callback, args, loop, context=None):
+ if context is None:
+ context = contextvars.copy_context()
+ self._context = context
self._loop = loop
self._callback = callback
self._args = args
@@ -80,7 +85,7 @@ class Handle:
def _run(self):
try:
- self._callback(*self._args)
+ self._context.run(self._callback, *self._args)
except Exception as exc:
cb = format_helpers._format_callback_source(
self._callback, self._args)
@@ -101,9 +106,9 @@ class TimerHandle(Handle):
__slots__ = ['_scheduled', '_when']
- def __init__(self, when, callback, args, loop):
+ def __init__(self, when, callback, args, loop, context=None):
assert when is not None
- super().__init__(callback, args, loop)
+ super().__init__(callback, args, loop, context)
if self._source_traceback:
del self._source_traceback[-1]
self._when = when
diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py
index 1c05b22..59621ff 100644
--- a/Lib/asyncio/futures.py
+++ b/Lib/asyncio/futures.py
@@ -6,6 +6,7 @@ __all__ = (
)
import concurrent.futures
+import contextvars
import logging
import sys
@@ -144,8 +145,8 @@ class Future:
return
self._callbacks[:] = []
- for callback in callbacks:
- self._loop.call_soon(callback, self)
+ for callback, ctx in callbacks:
+ self._loop.call_soon(callback, self, context=ctx)
def cancelled(self):
"""Return True if the future was cancelled."""
@@ -192,7 +193,7 @@ class Future:
self.__log_traceback = False
return self._exception
- def add_done_callback(self, fn):
+ def add_done_callback(self, fn, *, context=None):
"""Add a callback to be run when the future becomes done.
The callback is called with a single argument - the future object. If
@@ -200,9 +201,11 @@ class Future:
scheduled with call_soon.
"""
if self._state != _PENDING:
- self._loop.call_soon(fn, self)
+ self._loop.call_soon(fn, self, context=context)
else:
- self._callbacks.append(fn)
+ if context is None:
+ context = contextvars.copy_context()
+ self._callbacks.append((fn, context))
# New method not in PEP 3148.
@@ -211,7 +214,9 @@ class Future:
Returns the number of callbacks removed.
"""
- filtered_callbacks = [f for f in self._callbacks if f != fn]
+ filtered_callbacks = [(f, ctx)
+ for (f, ctx) in self._callbacks
+ if f != fn]
removed_count = len(self._callbacks) - len(filtered_callbacks)
if removed_count:
self._callbacks[:] = filtered_callbacks
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 5692e38..9446ae6 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -256,7 +256,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _add_reader(self, fd, callback, *args):
self._check_closed()
- handle = events.Handle(callback, args, self)
+ handle = events.Handle(callback, args, self, None)
try:
key = self._selector.get_key(fd)
except KeyError:
@@ -292,7 +292,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _add_writer(self, fd, callback, *args):
self._check_closed()
- handle = events.Handle(callback, args, self)
+ handle = events.Handle(callback, args, self, None)
try:
key = self._selector.get_key(fd)
except KeyError:
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index b118088..609b8e8 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -10,6 +10,7 @@ __all__ = (
)
import concurrent.futures
+import contextvars
import functools
import inspect
import types
@@ -96,8 +97,9 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
+ self._context = contextvars.copy_context()
- self._loop.call_soon(self._step)
+ self._loop.call_soon(self._step, context=self._context)
_register_task(self)
def __del__(self):
@@ -229,15 +231,18 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
new_exc = RuntimeError(
f'Task {self!r} got Future '
f'{result!r} attached to a different loop')
- self._loop.call_soon(self._step, new_exc)
+ self._loop.call_soon(
+ self._step, new_exc, context=self._context)
elif blocking:
if result is self:
new_exc = RuntimeError(
f'Task cannot await on itself: {self!r}')
- self._loop.call_soon(self._step, new_exc)
+ self._loop.call_soon(
+ self._step, new_exc, context=self._context)
else:
result._asyncio_future_blocking = False
- result.add_done_callback(self._wakeup)
+ result.add_done_callback(
+ self._wakeup, context=self._context)
self._fut_waiter = result
if self._must_cancel:
if self._fut_waiter.cancel():
@@ -246,21 +251,24 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
new_exc = RuntimeError(
f'yield was used instead of yield from '
f'in task {self!r} with {result!r}')
- self._loop.call_soon(self._step, new_exc)
+ self._loop.call_soon(
+ self._step, new_exc, context=self._context)
elif result is None:
# Bare yield relinquishes control for one event loop iteration.
- self._loop.call_soon(self._step)
+ self._loop.call_soon(self._step, context=self._context)
elif inspect.isgenerator(result):
# Yielding a generator is just wrong.
new_exc = RuntimeError(
f'yield was used instead of yield from for '
f'generator in task {self!r} with {result}')
- self._loop.call_soon(self._step, new_exc)
+ self._loop.call_soon(
+ self._step, new_exc, context=self._context)
else:
# Yielding something else is an error.
new_exc = RuntimeError(f'Task got bad yield: {result!r}')
- self._loop.call_soon(self._step, new_exc)
+ self._loop.call_soon(
+ self._step, new_exc, context=self._context)
finally:
_leave_task(self._loop, self)
self = None # Needed to break cycles when an exception occurs.
diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py
index 028a0ca..9b9d004 100644
--- a/Lib/asyncio/unix_events.py
+++ b/Lib/asyncio/unix_events.py
@@ -92,7 +92,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
except (ValueError, OSError) as exc:
raise RuntimeError(str(exc))
- handle = events.Handle(callback, args, self)
+ handle = events.Handle(callback, args, self, None)
self._signal_handlers[sig] = handle
try:
diff --git a/Lib/contextvars.py b/Lib/contextvars.py
new file mode 100644
index 0000000..d78c80d
--- /dev/null
+++ b/Lib/contextvars.py
@@ -0,0 +1,4 @@
+from _contextvars import Context, ContextVar, Token, copy_context
+
+
+__all__ = ('Context', 'ContextVar', 'Token', 'copy_context')
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index fc3b810..8d72df6 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -192,14 +192,14 @@ class BaseEventLoopTests(test_utils.TestCase):
self.assertRaises(RuntimeError, self.loop.run_until_complete, f)
def test__add_callback_handle(self):
- h = asyncio.Handle(lambda: False, (), self.loop)
+ h = asyncio.Handle(lambda: False, (), self.loop, None)
self.loop._add_callback(h)
self.assertFalse(self.loop._scheduled)
self.assertIn(h, self.loop._ready)
def test__add_callback_cancelled_handle(self):
- h = asyncio.Handle(lambda: False, (), self.loop)
+ h = asyncio.Handle(lambda: False, (), self.loop, None)
h.cancel()
self.loop._add_callback(h)
@@ -333,9 +333,9 @@ class BaseEventLoopTests(test_utils.TestCase):
def test__run_once(self):
h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (),
- self.loop)
+ self.loop, None)
h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (),
- self.loop)
+ self.loop, None)
h1.cancel()
@@ -390,7 +390,7 @@ class BaseEventLoopTests(test_utils.TestCase):
handle = loop.call_soon(lambda: True)
h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,),
- self.loop)
+ self.loop, None)
self.loop._process_events = mock.Mock()
self.loop._scheduled.append(h)
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
index ab45ee3..37f4c65 100644
--- a/Lib/test/test_asyncio/test_futures.py
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -565,16 +565,22 @@ class BaseFutureTests:
@unittest.skipUnless(hasattr(futures, '_CFuture'),
'requires the C _asyncio module')
class CFutureTests(BaseFutureTests, test_utils.TestCase):
- cls = futures._CFuture
+ try:
+ cls = futures._CFuture
+ except AttributeError:
+ cls = None
@unittest.skipUnless(hasattr(futures, '_CFuture'),
'requires the C _asyncio module')
class CSubFutureTests(BaseFutureTests, test_utils.TestCase):
- class CSubFuture(futures._CFuture):
- pass
+ try:
+ class CSubFuture(futures._CFuture):
+ pass
- cls = CSubFuture
+ cls = CSubFuture
+ except AttributeError:
+ cls = None
class PyFutureTests(BaseFutureTests, test_utils.TestCase):
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 26e4f64..96d2658 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -2,10 +2,11 @@
import collections
import contextlib
+import contextvars
import functools
import gc
import io
-import os
+import random
import re
import sys
import types
@@ -1377,9 +1378,9 @@ class BaseTaskTests:
self.cb_added = False
super().__init__(*args, **kwds)
- def add_done_callback(self, fn):
+ def add_done_callback(self, *args, **kwargs):
self.cb_added = True
- super().add_done_callback(fn)
+ super().add_done_callback(*args, **kwargs)
fut = Fut(loop=self.loop)
result = None
@@ -2091,7 +2092,7 @@ class BaseTaskTests:
@mock.patch('asyncio.base_events.logger')
def test_error_in_call_soon(self, m_log):
- def call_soon(callback, *args):
+ def call_soon(callback, *args, **kwargs):
raise ValueError
self.loop.call_soon = call_soon
@@ -2176,6 +2177,91 @@ class BaseTaskTests:
self.loop.run_until_complete(coro())
+ def test_context_1(self):
+ cvar = contextvars.ContextVar('cvar', default='nope')
+
+ async def sub():
+ await asyncio.sleep(0.01, loop=loop)
+ self.assertEqual(cvar.get(), 'nope')
+ cvar.set('something else')
+
+ async def main():
+ self.assertEqual(cvar.get(), 'nope')
+ subtask = self.new_task(loop, sub())
+ cvar.set('yes')
+ self.assertEqual(cvar.get(), 'yes')
+ await subtask
+ self.assertEqual(cvar.get(), 'yes')
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ def test_context_2(self):
+ cvar = contextvars.ContextVar('cvar', default='nope')
+
+ async def main():
+ def fut_on_done(fut):
+ # This change must not pollute the context
+ # of the "main()" task.
+ cvar.set('something else')
+
+ self.assertEqual(cvar.get(), 'nope')
+
+ for j in range(2):
+ fut = self.new_future(loop)
+ fut.add_done_callback(fut_on_done)
+ cvar.set(f'yes{j}')
+ loop.call_soon(fut.set_result, None)
+ await fut
+ self.assertEqual(cvar.get(), f'yes{j}')
+
+ for i in range(3):
+ # Test that task passed its context to add_done_callback:
+ cvar.set(f'yes{i}-{j}')
+ await asyncio.sleep(0.001, loop=loop)
+ self.assertEqual(cvar.get(), f'yes{i}-{j}')
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual(cvar.get(), 'nope')
+
+ def test_context_3(self):
+ # Run 100 Tasks in parallel, each modifying cvar.
+
+ cvar = contextvars.ContextVar('cvar', default=-1)
+
+ async def sub(num):
+ for i in range(10):
+ cvar.set(num + i)
+ await asyncio.sleep(
+ random.uniform(0.001, 0.05), loop=loop)
+ self.assertEqual(cvar.get(), num + i)
+
+ async def main():
+ tasks = []
+ for i in range(100):
+ task = loop.create_task(sub(random.randint(0, 10)))
+ tasks.append(task)
+
+ await asyncio.gather(*tasks, loop=loop)
+
+ loop = asyncio.new_event_loop()
+ try:
+ loop.run_until_complete(main())
+ finally:
+ loop.close()
+
+ self.assertEqual(cvar.get(), -1)
+
def add_subclass_tests(cls):
BaseTask = cls.Task
@@ -2193,9 +2279,9 @@ def add_subclass_tests(cls):
self.calls['_schedule_callbacks'] += 1
return super()._schedule_callbacks()
- def add_done_callback(self, *args):
+ def add_done_callback(self, *args, **kwargs):
self.calls['add_done_callback'] += 1
- return super().add_done_callback(*args)
+ return super().add_done_callback(*args, **kwargs)
class Task(CommonFuture, BaseTask):
def _step(self, *args):
@@ -2486,10 +2572,13 @@ class PyIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests):
@unittest.skipUnless(hasattr(tasks, '_c_register_task'),
'requires the C _asyncio module')
class CIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests):
- _register_task = staticmethod(tasks._c_register_task)
- _unregister_task = staticmethod(tasks._c_unregister_task)
- _enter_task = staticmethod(tasks._c_enter_task)
- _leave_task = staticmethod(tasks._c_leave_task)
+ if hasattr(tasks, '_c_register_task'):
+ _register_task = staticmethod(tasks._c_register_task)
+ _unregister_task = staticmethod(tasks._c_unregister_task)
+ _enter_task = staticmethod(tasks._c_enter_task)
+ _leave_task = staticmethod(tasks._c_leave_task)
+ else:
+ _register_task = _unregister_task = _enter_task = _leave_task = None
class BaseCurrentLoopTests:
diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py
index f756ec9..96dfe2f 100644
--- a/Lib/test/test_asyncio/utils.py
+++ b/Lib/test/test_asyncio/utils.py
@@ -365,7 +365,7 @@ class TestLoop(base_events.BaseEventLoop):
raise AssertionError("Time generator is not finished")
def _add_reader(self, fd, callback, *args):
- self.readers[fd] = events.Handle(callback, args, self)
+ self.readers[fd] = events.Handle(callback, args, self, None)
def _remove_reader(self, fd):
self.remove_reader_count[fd] += 1
@@ -391,7 +391,7 @@ class TestLoop(base_events.BaseEventLoop):
raise AssertionError(f'fd {fd} is registered')
def _add_writer(self, fd, callback, *args):
- self.writers[fd] = events.Handle(callback, args, self)
+ self.writers[fd] = events.Handle(callback, args, self, None)
def _remove_writer(self, fd):
self.remove_writer_count[fd] += 1
@@ -457,9 +457,9 @@ class TestLoop(base_events.BaseEventLoop):
self.advance_time(advance)
self._timers = []
- def call_at(self, when, callback, *args):
+ def call_at(self, when, callback, *args, context=None):
self._timers.append(when)
- return super().call_at(when, callback, *args)
+ return super().call_at(when, callback, *args, context=context)
def _process_events(self, event_list):
return
diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py
new file mode 100644
index 0000000..74d05fc
--- /dev/null
+++ b/Lib/test/test_context.py
@@ -0,0 +1,1064 @@
+import concurrent.futures
+import contextvars
+import functools
+import gc
+import random
+import time
+import unittest
+import weakref
+
+try:
+ from _testcapi import hamt
+except ImportError:
+ hamt = None
+
+
+def isolated_context(func):
+ """Needed to make reftracking test mode work."""
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ ctx = contextvars.Context()
+ return ctx.run(func, *args, **kwargs)
+ return wrapper
+
+
+class ContextTest(unittest.TestCase):
+ def test_context_var_new_1(self):
+ with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
+ contextvars.ContextVar()
+
+ with self.assertRaisesRegex(TypeError, 'must be a str'):
+ contextvars.ContextVar(1)
+
+ c = contextvars.ContextVar('a')
+ self.assertNotEqual(hash(c), hash('a'))
+
+ def test_context_var_new_2(self):
+ self.assertIsNone(contextvars.ContextVar[int])
+
+ @isolated_context
+ def test_context_var_repr_1(self):
+ c = contextvars.ContextVar('a')
+ self.assertIn('a', repr(c))
+
+ c = contextvars.ContextVar('a', default=123)
+ self.assertIn('123', repr(c))
+
+ lst = []
+ c = contextvars.ContextVar('a', default=lst)
+ lst.append(c)
+ self.assertIn('...', repr(c))
+ self.assertIn('...', repr(lst))
+
+ t = c.set(1)
+ self.assertIn(repr(c), repr(t))
+ self.assertNotIn(' used ', repr(t))
+ c.reset(t)
+ self.assertIn(' used ', repr(t))
+
+ def test_context_subclassing_1(self):
+ with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
+ class MyContextVar(contextvars.ContextVar):
+ # Potentially we might want ContextVars to be subclassable.
+ pass
+
+ with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
+ class MyContext(contextvars.Context):
+ pass
+
+ with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
+ class MyToken(contextvars.Token):
+ pass
+
+ def test_context_new_1(self):
+ with self.assertRaisesRegex(TypeError, 'any arguments'):
+ contextvars.Context(1)
+ with self.assertRaisesRegex(TypeError, 'any arguments'):
+ contextvars.Context(1, a=1)
+ with self.assertRaisesRegex(TypeError, 'any arguments'):
+ contextvars.Context(a=1)
+ contextvars.Context(**{})
+
+ def test_context_typerrors_1(self):
+ ctx = contextvars.Context()
+
+ with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
+ ctx[1]
+ with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
+ 1 in ctx
+ with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
+ ctx.get(1)
+
+ def test_context_get_context_1(self):
+ ctx = contextvars.copy_context()
+ self.assertIsInstance(ctx, contextvars.Context)
+
+ def test_context_run_1(self):
+ ctx = contextvars.Context()
+
+ with self.assertRaisesRegex(TypeError, 'missing 1 required'):
+ ctx.run()
+
+ def test_context_run_2(self):
+ ctx = contextvars.Context()
+
+ def func(*args, **kwargs):
+ kwargs['spam'] = 'foo'
+ args += ('bar',)
+ return args, kwargs
+
+ for f in (func, functools.partial(func)):
+ # partial doesn't support FASTCALL
+
+ self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
+ self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))
+
+ self.assertEqual(
+ ctx.run(f, a=2),
+ (('bar',), {'a': 2, 'spam': 'foo'}))
+
+ self.assertEqual(
+ ctx.run(f, 11, a=2),
+ ((11, 'bar'), {'a': 2, 'spam': 'foo'}))
+
+ a = {}
+ self.assertEqual(
+ ctx.run(f, 11, **a),
+ ((11, 'bar'), {'spam': 'foo'}))
+ self.assertEqual(a, {})
+
+ def test_context_run_3(self):
+ ctx = contextvars.Context()
+
+ def func(*args, **kwargs):
+ 1 / 0
+
+ with self.assertRaises(ZeroDivisionError):
+ ctx.run(func)
+ with self.assertRaises(ZeroDivisionError):
+ ctx.run(func, 1, 2)
+ with self.assertRaises(ZeroDivisionError):
+ ctx.run(func, 1, 2, a=123)
+
+ @isolated_context
+ def test_context_run_4(self):
+ ctx1 = contextvars.Context()
+ ctx2 = contextvars.Context()
+ var = contextvars.ContextVar('var')
+
+ def func2():
+ self.assertIsNone(var.get(None))
+
+ def func1():
+ self.assertIsNone(var.get(None))
+ var.set('spam')
+ ctx2.run(func2)
+ self.assertEqual(var.get(None), 'spam')
+
+ cur = contextvars.copy_context()
+ self.assertEqual(len(cur), 1)
+ self.assertEqual(cur[var], 'spam')
+ return cur
+
+ returned_ctx = ctx1.run(func1)
+ self.assertEqual(ctx1, returned_ctx)
+ self.assertEqual(returned_ctx[var], 'spam')
+ self.assertIn(var, returned_ctx)
+
+ def test_context_run_5(self):
+ ctx = contextvars.Context()
+ var = contextvars.ContextVar('var')
+
+ def func():
+ self.assertIsNone(var.get(None))
+ var.set('spam')
+ 1 / 0
+
+ with self.assertRaises(ZeroDivisionError):
+ ctx.run(func)
+
+ self.assertIsNone(var.get(None))
+
+ def test_context_run_6(self):
+ ctx = contextvars.Context()
+ c = contextvars.ContextVar('a', default=0)
+
+ def fun():
+ self.assertEqual(c.get(), 0)
+ self.assertIsNone(ctx.get(c))
+
+ c.set(42)
+ self.assertEqual(c.get(), 42)
+ self.assertEqual(ctx.get(c), 42)
+
+ ctx.run(fun)
+
+ def test_context_run_7(self):
+ ctx = contextvars.Context()
+
+ def fun():
+ with self.assertRaisesRegex(RuntimeError, 'is already entered'):
+ ctx.run(fun)
+
+ ctx.run(fun)
+
+ @isolated_context
+ def test_context_getset_1(self):
+ c = contextvars.ContextVar('c')
+ with self.assertRaises(LookupError):
+ c.get()
+
+ self.assertIsNone(c.get(None))
+
+ t0 = c.set(42)
+ self.assertEqual(c.get(), 42)
+ self.assertEqual(c.get(None), 42)
+ self.assertIs(t0.old_value, t0.MISSING)
+ self.assertIs(t0.old_value, contextvars.Token.MISSING)
+ self.assertIs(t0.var, c)
+
+ t = c.set('spam')
+ self.assertEqual(c.get(), 'spam')
+ self.assertEqual(c.get(None), 'spam')
+ self.assertEqual(t.old_value, 42)
+ c.reset(t)
+
+ self.assertEqual(c.get(), 42)
+ self.assertEqual(c.get(None), 42)
+
+ c.set('spam2')
+ with self.assertRaisesRegex(RuntimeError, 'has already been used'):
+ c.reset(t)
+ self.assertEqual(c.get(), 'spam2')
+
+ ctx1 = contextvars.copy_context()
+ self.assertIn(c, ctx1)
+
+ c.reset(t0)
+ with self.assertRaisesRegex(RuntimeError, 'has already been used'):
+ c.reset(t0)
+ self.assertIsNone(c.get(None))
+
+ self.assertIn(c, ctx1)
+ self.assertEqual(ctx1[c], 'spam2')
+ self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
+ self.assertEqual(len(ctx1), 1)
+ self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
+ self.assertEqual(list(ctx1.values()), ['spam2'])
+ self.assertEqual(list(ctx1.keys()), [c])
+ self.assertEqual(list(ctx1), [c])
+
+ ctx2 = contextvars.copy_context()
+ self.assertNotIn(c, ctx2)
+ with self.assertRaises(KeyError):
+ ctx2[c]
+ self.assertEqual(ctx2.get(c, 'aa'), 'aa')
+ self.assertEqual(len(ctx2), 0)
+ self.assertEqual(list(ctx2), [])
+
+ @isolated_context
+ def test_context_getset_2(self):
+ v1 = contextvars.ContextVar('v1')
+ v2 = contextvars.ContextVar('v2')
+
+ t1 = v1.set(42)
+ with self.assertRaisesRegex(ValueError, 'by a different'):
+ v2.reset(t1)
+
+ @isolated_context
+ def test_context_getset_3(self):
+ c = contextvars.ContextVar('c', default=42)
+ ctx = contextvars.Context()
+
+ def fun():
+ self.assertEqual(c.get(), 42)
+ with self.assertRaises(KeyError):
+ ctx[c]
+ self.assertIsNone(ctx.get(c))
+ self.assertEqual(ctx.get(c, 'spam'), 'spam')
+ self.assertNotIn(c, ctx)
+ self.assertEqual(list(ctx.keys()), [])
+
+ t = c.set(1)
+ self.assertEqual(list(ctx.keys()), [c])
+ self.assertEqual(ctx[c], 1)
+
+ c.reset(t)
+ self.assertEqual(list(ctx.keys()), [])
+ with self.assertRaises(KeyError):
+ ctx[c]
+
+ ctx.run(fun)
+
+ @isolated_context
+ def test_context_getset_4(self):
+ c = contextvars.ContextVar('c', default=42)
+ ctx = contextvars.Context()
+
+ tok = ctx.run(c.set, 1)
+
+ with self.assertRaisesRegex(ValueError, 'different Context'):
+ c.reset(tok)
+
+ @isolated_context
+ def test_context_getset_5(self):
+ c = contextvars.ContextVar('c', default=42)
+ c.set([])
+
+ def fun():
+ c.set([])
+ c.get().append(42)
+ self.assertEqual(c.get(), [42])
+
+ contextvars.copy_context().run(fun)
+ self.assertEqual(c.get(), [])
+
+ def test_context_copy_1(self):
+ ctx1 = contextvars.Context()
+ c = contextvars.ContextVar('c', default=42)
+
+ def ctx1_fun():
+ c.set(10)
+
+ ctx2 = ctx1.copy()
+ self.assertEqual(ctx2[c], 10)
+
+ c.set(20)
+ self.assertEqual(ctx1[c], 20)
+ self.assertEqual(ctx2[c], 10)
+
+ ctx2.run(ctx2_fun)
+ self.assertEqual(ctx1[c], 20)
+ self.assertEqual(ctx2[c], 30)
+
+ def ctx2_fun():
+ self.assertEqual(c.get(), 10)
+ c.set(30)
+ self.assertEqual(c.get(), 30)
+
+ ctx1.run(ctx1_fun)
+
+ @isolated_context
+ def test_context_threads_1(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ def sub(num):
+ for i in range(10):
+ cvar.set(num + i)
+ time.sleep(random.uniform(0.001, 0.05))
+ self.assertEqual(cvar.get(), num + i)
+ return num
+
+ tp = concurrent.futures.ThreadPoolExecutor(max_workers=10)
+ try:
+ results = list(tp.map(sub, range(10)))
+ finally:
+ tp.shutdown()
+ self.assertEqual(results, list(range(10)))
+
+
+# HAMT Tests
+
+
+class HashKey:
+ _crasher = None
+
+ def __init__(self, hash, name, *, error_on_eq_to=None):
+ assert hash != -1
+ self.name = name
+ self.hash = hash
+ self.error_on_eq_to = error_on_eq_to
+
+ def __repr__(self):
+ return f'<Key name:{self.name} hash:{self.hash}>'
+
+ def __hash__(self):
+ if self._crasher is not None and self._crasher.error_on_hash:
+ raise HashingError
+
+ return self.hash
+
+ def __eq__(self, other):
+ if not isinstance(other, HashKey):
+ return NotImplemented
+
+ if self._crasher is not None and self._crasher.error_on_eq:
+ raise EqError
+
+ if self.error_on_eq_to is not None and self.error_on_eq_to is other:
+ raise ValueError(f'cannot compare {self!r} to {other!r}')
+ if other.error_on_eq_to is not None and other.error_on_eq_to is self:
+ raise ValueError(f'cannot compare {other!r} to {self!r}')
+
+ return (self.name, self.hash) == (other.name, other.hash)
+
+
+class KeyStr(str):
+ def __hash__(self):
+ if HashKey._crasher is not None and HashKey._crasher.error_on_hash:
+ raise HashingError
+ return super().__hash__()
+
+ def __eq__(self, other):
+ if HashKey._crasher is not None and HashKey._crasher.error_on_eq:
+ raise EqError
+ return super().__eq__(other)
+
+
+class HaskKeyCrasher:
+ def __init__(self, *, error_on_hash=False, error_on_eq=False):
+ self.error_on_hash = error_on_hash
+ self.error_on_eq = error_on_eq
+
+ def __enter__(self):
+ if HashKey._crasher is not None:
+ raise RuntimeError('cannot nest crashers')
+ HashKey._crasher = self
+
+ def __exit__(self, *exc):
+ HashKey._crasher = None
+
+
+class HashingError(Exception):
+ pass
+
+
+class EqError(Exception):
+ pass
+
+
+@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
+class HamtTest(unittest.TestCase):
+
+ def test_hashkey_helper_1(self):
+ k1 = HashKey(10, 'aaa')
+ k2 = HashKey(10, 'bbb')
+
+ self.assertNotEqual(k1, k2)
+ self.assertEqual(hash(k1), hash(k2))
+
+ d = dict()
+ d[k1] = 'a'
+ d[k2] = 'b'
+
+ self.assertEqual(d[k1], 'a')
+ self.assertEqual(d[k2], 'b')
+
+ def test_hamt_basics_1(self):
+ h = hamt()
+ h = None # NoQA
+
+ def test_hamt_basics_2(self):
+ h = hamt()
+ self.assertEqual(len(h), 0)
+
+ h2 = h.set('a', 'b')
+ self.assertIsNot(h, h2)
+ self.assertEqual(len(h), 0)
+ self.assertEqual(len(h2), 1)
+
+ self.assertIsNone(h.get('a'))
+ self.assertEqual(h.get('a', 42), 42)
+
+ self.assertEqual(h2.get('a'), 'b')
+
+ h3 = h2.set('b', 10)
+ self.assertIsNot(h2, h3)
+ self.assertEqual(len(h), 0)
+ self.assertEqual(len(h2), 1)
+ self.assertEqual(len(h3), 2)
+ self.assertEqual(h3.get('a'), 'b')
+ self.assertEqual(h3.get('b'), 10)
+
+ self.assertIsNone(h.get('b'))
+ self.assertIsNone(h2.get('b'))
+
+ self.assertIsNone(h.get('a'))
+ self.assertEqual(h2.get('a'), 'b')
+
+ h = h2 = h3 = None
+
+ def test_hamt_basics_3(self):
+ h = hamt()
+ o = object()
+ h1 = h.set('1', o)
+ h2 = h1.set('1', o)
+ self.assertIs(h1, h2)
+
+ def test_hamt_basics_4(self):
+ h = hamt()
+ h1 = h.set('key', [])
+ h2 = h1.set('key', [])
+ self.assertIsNot(h1, h2)
+ self.assertEqual(len(h1), 1)
+ self.assertEqual(len(h2), 1)
+ self.assertIsNot(h1.get('key'), h2.get('key'))
+
+ def test_hamt_collision_1(self):
+ k1 = HashKey(10, 'aaa')
+ k2 = HashKey(10, 'bbb')
+ k3 = HashKey(10, 'ccc')
+
+ h = hamt()
+ h2 = h.set(k1, 'a')
+ h3 = h2.set(k2, 'b')
+
+ self.assertEqual(h.get(k1), None)
+ self.assertEqual(h.get(k2), None)
+
+ self.assertEqual(h2.get(k1), 'a')
+ self.assertEqual(h2.get(k2), None)
+
+ self.assertEqual(h3.get(k1), 'a')
+ self.assertEqual(h3.get(k2), 'b')
+
+ h4 = h3.set(k2, 'cc')
+ h5 = h4.set(k3, 'aa')
+
+ self.assertEqual(h3.get(k1), 'a')
+ self.assertEqual(h3.get(k2), 'b')
+ self.assertEqual(h4.get(k1), 'a')
+ self.assertEqual(h4.get(k2), 'cc')
+ self.assertEqual(h4.get(k3), None)
+ self.assertEqual(h5.get(k1), 'a')
+ self.assertEqual(h5.get(k2), 'cc')
+ self.assertEqual(h5.get(k2), 'cc')
+ self.assertEqual(h5.get(k3), 'aa')
+
+ self.assertEqual(len(h), 0)
+ self.assertEqual(len(h2), 1)
+ self.assertEqual(len(h3), 2)
+ self.assertEqual(len(h4), 2)
+ self.assertEqual(len(h5), 3)
+
+ def test_hamt_stress(self):
+ COLLECTION_SIZE = 7000
+ TEST_ITERS_EVERY = 647
+ CRASH_HASH_EVERY = 97
+ CRASH_EQ_EVERY = 11
+ RUN_XTIMES = 3
+
+ for _ in range(RUN_XTIMES):
+ h = hamt()
+ d = dict()
+
+ for i in range(COLLECTION_SIZE):
+ key = KeyStr(i)
+
+ if not (i % CRASH_HASH_EVERY):
+ with HaskKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ h.set(key, i)
+
+ h = h.set(key, i)
+
+ if not (i % CRASH_EQ_EVERY):
+ with HaskKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ h.get(KeyStr(i)) # really trigger __eq__
+
+ d[key] = i
+ self.assertEqual(len(d), len(h))
+
+ if not (i % TEST_ITERS_EVERY):
+ self.assertEqual(set(h.items()), set(d.items()))
+ self.assertEqual(len(h.items()), len(d.items()))
+
+ self.assertEqual(len(h), COLLECTION_SIZE)
+
+ for key in range(COLLECTION_SIZE):
+ self.assertEqual(h.get(KeyStr(key), 'not found'), key)
+
+ keys_to_delete = list(range(COLLECTION_SIZE))
+ random.shuffle(keys_to_delete)
+ for iter_i, i in enumerate(keys_to_delete):
+ key = KeyStr(i)
+
+ if not (iter_i % CRASH_HASH_EVERY):
+ with HaskKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ h.delete(key)
+
+ if not (iter_i % CRASH_EQ_EVERY):
+ with HaskKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ h.delete(KeyStr(i))
+
+ h = h.delete(key)
+ self.assertEqual(h.get(key, 'not found'), 'not found')
+ del d[key]
+ self.assertEqual(len(d), len(h))
+
+ if iter_i == COLLECTION_SIZE // 2:
+ hm = h
+ dm = d.copy()
+
+ if not (iter_i % TEST_ITERS_EVERY):
+ self.assertEqual(set(h.keys()), set(d.keys()))
+ self.assertEqual(len(h.keys()), len(d.keys()))
+
+ self.assertEqual(len(d), 0)
+ self.assertEqual(len(h), 0)
+
+ # ============
+
+ for key in dm:
+ self.assertEqual(hm.get(str(key)), dm[key])
+ self.assertEqual(len(dm), len(hm))
+
+ for i, key in enumerate(keys_to_delete):
+ hm = hm.delete(str(key))
+ self.assertEqual(hm.get(str(key), 'not found'), 'not found')
+ dm.pop(str(key), None)
+ self.assertEqual(len(d), len(h))
+
+ if not (i % TEST_ITERS_EVERY):
+ self.assertEqual(set(h.values()), set(d.values()))
+ self.assertEqual(len(h.values()), len(d.values()))
+
+ self.assertEqual(len(d), 0)
+ self.assertEqual(len(h), 0)
+ self.assertEqual(list(h.items()), [])
+
+ def test_hamt_delete_1(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+ C = HashKey(102, 'C')
+ D = HashKey(103, 'D')
+ E = HashKey(104, 'E')
+ Z = HashKey(-100, 'Z')
+
+ Er = HashKey(103, 'Er', error_on_eq_to=D)
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+
+ orig_len = len(h)
+
+ # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
+ # <Key name:A hash:100>: 'a'
+ # <Key name:B hash:101>: 'b'
+ # <Key name:C hash:102>: 'c'
+ # <Key name:D hash:103>: 'd'
+ # <Key name:E hash:104>: 'e'
+
+ h = h.delete(C)
+ self.assertEqual(len(h), orig_len - 1)
+
+ with self.assertRaisesRegex(ValueError, 'cannot compare'):
+ h.delete(Er)
+
+ h = h.delete(D)
+ self.assertEqual(len(h), orig_len - 2)
+
+ h2 = h.delete(Z)
+ self.assertIs(h2, h)
+
+ h = h.delete(A)
+ self.assertEqual(len(h), orig_len - 3)
+
+ self.assertEqual(h.get(A, 42), 42)
+ self.assertEqual(h.get(B), 'b')
+ self.assertEqual(h.get(E), 'e')
+
+ def test_hamt_delete_2(self):
+ A = HashKey(100, 'A')
+ B = HashKey(201001, 'B')
+ C = HashKey(101001, 'C')
+ D = HashKey(103, 'D')
+ E = HashKey(104, 'E')
+ Z = HashKey(-100, 'Z')
+
+ Er = HashKey(201001, 'Er', error_on_eq_to=B)
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+
+ orig_len = len(h)
+
+ # BitmapNode(size=8 bitmap=0b1110010000):
+ # <Key name:A hash:100>: 'a'
+ # <Key name:D hash:103>: 'd'
+ # <Key name:E hash:104>: 'e'
+ # NULL:
+ # BitmapNode(size=4 bitmap=0b100000000001000000000):
+ # <Key name:B hash:201001>: 'b'
+ # <Key name:C hash:101001>: 'c'
+
+ with self.assertRaisesRegex(ValueError, 'cannot compare'):
+ h.delete(Er)
+
+ h = h.delete(Z)
+ self.assertEqual(len(h), orig_len)
+
+ h = h.delete(C)
+ self.assertEqual(len(h), orig_len - 1)
+
+ h = h.delete(B)
+ self.assertEqual(len(h), orig_len - 2)
+
+ h = h.delete(A)
+ self.assertEqual(len(h), orig_len - 3)
+
+ self.assertEqual(h.get(D), 'd')
+ self.assertEqual(h.get(E), 'e')
+
+ h = h.delete(A)
+ h = h.delete(B)
+ h = h.delete(D)
+ h = h.delete(E)
+ self.assertEqual(len(h), 0)
+
+ def test_hamt_delete_3(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+ C = HashKey(100100, 'C')
+ D = HashKey(100100, 'D')
+ E = HashKey(104, 'E')
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+
+ orig_len = len(h)
+
+ # BitmapNode(size=6 bitmap=0b100110000):
+ # NULL:
+ # BitmapNode(size=4 bitmap=0b1000000000000000000001000):
+ # <Key name:A hash:100>: 'a'
+ # NULL:
+ # CollisionNode(size=4 id=0x108572410):
+ # <Key name:C hash:100100>: 'c'
+ # <Key name:D hash:100100>: 'd'
+ # <Key name:B hash:101>: 'b'
+ # <Key name:E hash:104>: 'e'
+
+ h = h.delete(A)
+ self.assertEqual(len(h), orig_len - 1)
+
+ h = h.delete(E)
+ self.assertEqual(len(h), orig_len - 2)
+
+ self.assertEqual(h.get(C), 'c')
+ self.assertEqual(h.get(B), 'b')
+
+ def test_hamt_delete_4(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+ C = HashKey(100100, 'C')
+ D = HashKey(100100, 'D')
+ E = HashKey(100100, 'E')
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+
+ orig_len = len(h)
+
+ # BitmapNode(size=4 bitmap=0b110000):
+ # NULL:
+ # BitmapNode(size=4 bitmap=0b1000000000000000000001000):
+ # <Key name:A hash:100>: 'a'
+ # NULL:
+ # CollisionNode(size=6 id=0x10515ef30):
+ # <Key name:C hash:100100>: 'c'
+ # <Key name:D hash:100100>: 'd'
+ # <Key name:E hash:100100>: 'e'
+ # <Key name:B hash:101>: 'b'
+
+ h = h.delete(D)
+ self.assertEqual(len(h), orig_len - 1)
+
+ h = h.delete(E)
+ self.assertEqual(len(h), orig_len - 2)
+
+ h = h.delete(C)
+ self.assertEqual(len(h), orig_len - 3)
+
+ h = h.delete(A)
+ self.assertEqual(len(h), orig_len - 4)
+
+ h = h.delete(B)
+ self.assertEqual(len(h), 0)
+
+ def test_hamt_delete_5(self):
+ h = hamt()
+
+ keys = []
+ for i in range(17):
+ key = HashKey(i, str(i))
+ keys.append(key)
+ h = h.set(key, f'val-{i}')
+
+ collision_key16 = HashKey(16, '18')
+ h = h.set(collision_key16, 'collision')
+
+ # ArrayNode(id=0x10f8b9318):
+ # 0::
+ # BitmapNode(size=2 count=1 bitmap=0b1):
+ # <Key name:0 hash:0>: 'val-0'
+ #
+ # ... 14 more BitmapNodes ...
+ #
+ # 15::
+ # BitmapNode(size=2 count=1 bitmap=0b1):
+ # <Key name:15 hash:15>: 'val-15'
+ #
+ # 16::
+ # BitmapNode(size=2 count=1 bitmap=0b1):
+ # NULL:
+ # CollisionNode(size=4 id=0x10f2f5af8):
+ # <Key name:16 hash:16>: 'val-16'
+ # <Key name:18 hash:16>: 'collision'
+
+ self.assertEqual(len(h), 18)
+
+ h = h.delete(keys[2])
+ self.assertEqual(len(h), 17)
+
+ h = h.delete(collision_key16)
+ self.assertEqual(len(h), 16)
+ h = h.delete(keys[16])
+ self.assertEqual(len(h), 15)
+
+ h = h.delete(keys[1])
+ self.assertEqual(len(h), 14)
+ h = h.delete(keys[1])
+ self.assertEqual(len(h), 14)
+
+ for key in keys:
+ h = h.delete(key)
+ self.assertEqual(len(h), 0)
+
+ def test_hamt_items_1(self):
+ A = HashKey(100, 'A')
+ B = HashKey(201001, 'B')
+ C = HashKey(101001, 'C')
+ D = HashKey(103, 'D')
+ E = HashKey(104, 'E')
+ F = HashKey(110, 'F')
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+ h = h.set(F, 'f')
+
+ it = h.items()
+ self.assertEqual(
+ set(list(it)),
+ {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
+
+ def test_hamt_items_2(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+ C = HashKey(100100, 'C')
+ D = HashKey(100100, 'D')
+ E = HashKey(100100, 'E')
+ F = HashKey(110, 'F')
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+ h = h.set(F, 'f')
+
+ it = h.items()
+ self.assertEqual(
+ set(list(it)),
+ {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
+
+ def test_hamt_keys_1(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+ C = HashKey(100100, 'C')
+ D = HashKey(100100, 'D')
+ E = HashKey(100100, 'E')
+ F = HashKey(110, 'F')
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(B, 'b')
+ h = h.set(C, 'c')
+ h = h.set(D, 'd')
+ h = h.set(E, 'e')
+ h = h.set(F, 'f')
+
+ self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
+ self.assertEqual(set(list(h)), {A, B, C, D, E, F})
+
+ def test_hamt_items_3(self):
+ h = hamt()
+ self.assertEqual(len(h.items()), 0)
+ self.assertEqual(list(h.items()), [])
+
+ def test_hamt_eq_1(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+ C = HashKey(100100, 'C')
+ D = HashKey(100100, 'D')
+ E = HashKey(120, 'E')
+
+ h1 = hamt()
+ h1 = h1.set(A, 'a')
+ h1 = h1.set(B, 'b')
+ h1 = h1.set(C, 'c')
+ h1 = h1.set(D, 'd')
+
+ h2 = hamt()
+ h2 = h2.set(A, 'a')
+
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h2 = h2.set(B, 'b')
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h2 = h2.set(C, 'c')
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h2 = h2.set(D, 'd2')
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h2 = h2.set(D, 'd')
+ self.assertTrue(h1 == h2)
+ self.assertFalse(h1 != h2)
+
+ h2 = h2.set(E, 'e')
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h2 = h2.delete(D)
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h2 = h2.set(E, 'd')
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ def test_hamt_eq_2(self):
+ A = HashKey(100, 'A')
+ Er = HashKey(100, 'Er', error_on_eq_to=A)
+
+ h1 = hamt()
+ h1 = h1.set(A, 'a')
+
+ h2 = hamt()
+ h2 = h2.set(Er, 'a')
+
+ with self.assertRaisesRegex(ValueError, 'cannot compare'):
+ h1 == h2
+
+ with self.assertRaisesRegex(ValueError, 'cannot compare'):
+ h1 != h2
+
+ def test_hamt_gc_1(self):
+ A = HashKey(100, 'A')
+
+ h = hamt()
+ h = h.set(0, 0) # empty HAMT node is memoized in hamt.c
+ ref = weakref.ref(h)
+
+ a = []
+ a.append(a)
+ a.append(h)
+ b = []
+ a.append(b)
+ b.append(a)
+ h = h.set(A, b)
+
+ del h, a, b
+
+ gc.collect()
+ gc.collect()
+ gc.collect()
+
+ self.assertIsNone(ref())
+
+ def test_hamt_gc_2(self):
+ A = HashKey(100, 'A')
+ B = HashKey(101, 'B')
+
+ h = hamt()
+ h = h.set(A, 'a')
+ h = h.set(A, h)
+
+ ref = weakref.ref(h)
+ hi = h.items()
+ next(hi)
+
+ del h, hi
+
+ gc.collect()
+ gc.collect()
+ gc.collect()
+
+ self.assertIsNone(ref())
+
+ def test_hamt_in_1(self):
+ A = HashKey(100, 'A')
+ AA = HashKey(100, 'A')
+
+ B = HashKey(101, 'B')
+
+ h = hamt()
+ h = h.set(A, 1)
+
+ self.assertTrue(A in h)
+ self.assertFalse(B in h)
+
+ with self.assertRaises(EqError):
+ with HaskKeyCrasher(error_on_eq=True):
+ AA in h
+
+ with self.assertRaises(HashingError):
+ with HaskKeyCrasher(error_on_hash=True):
+ AA in h
+
+ def test_hamt_getitem_1(self):
+ A = HashKey(100, 'A')
+ AA = HashKey(100, 'A')
+
+ B = HashKey(101, 'B')
+
+ h = hamt()
+ h = h.set(A, 1)
+
+ self.assertEqual(h[A], 1)
+ self.assertEqual(h[AA], 1)
+
+ with self.assertRaises(KeyError):
+ h[B]
+
+ with self.assertRaises(EqError):
+ with HaskKeyCrasher(error_on_eq=True):
+ h[AA]
+
+ with self.assertRaises(HashingError):
+ with HaskKeyCrasher(error_on_hash=True):
+ h[AA]
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Makefile.pre.in b/Makefile.pre.in
index d151267..162006a 100644
--- a/Makefile.pre.in
+++ b/Makefile.pre.in
@@ -354,6 +354,8 @@ PYTHON_OBJS= \
Python/pylifecycle.o \
Python/pymath.o \
Python/pystate.o \
+ Python/context.o \
+ Python/hamt.o \
Python/pythonrun.o \
Python/pytime.o \
Python/bootstrap_hash.o \
@@ -996,6 +998,7 @@ PYTHON_HEADERS= \
$(srcdir)/Include/pymem.h \
$(srcdir)/Include/pyport.h \
$(srcdir)/Include/pystate.h \
+ $(srcdir)/Include/context.h \
$(srcdir)/Include/pystrcmp.h \
$(srcdir)/Include/pystrtod.h \
$(srcdir)/Include/pystrhex.h \
@@ -1023,6 +1026,7 @@ PYTHON_HEADERS= \
$(srcdir)/Include/internal/mem.h \
$(srcdir)/Include/internal/pygetopt.h \
$(srcdir)/Include/internal/pystate.h \
+ $(srcdir)/Include/internal/context.h \
$(srcdir)/Include/internal/warnings.h \
$(DTRACE_HEADERS)
diff --git a/Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst b/Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst
new file mode 100644
index 0000000..8586d77
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst
@@ -0,0 +1 @@
+Implement PEP 567
diff --git a/Modules/Setup.dist b/Modules/Setup.dist
index 1f2d56c..239550e 100644
--- a/Modules/Setup.dist
+++ b/Modules/Setup.dist
@@ -176,6 +176,7 @@ _symtable symtablemodule.c
#array arraymodule.c # array objects
#cmath cmathmodule.c _math.c # -lm # complex math library functions
#math mathmodule.c _math.c # -lm # math library functions, e.g. sin()
+#_contextvars _contextvarsmodule.c # Context Variables
#_struct _struct.c # binary structure packing/unpacking
#_weakref _weakref.c # basic weak reference support
#_testcapi _testcapimodule.c # Python C API test module
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index 22ce32c..f77ec99 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -36,6 +36,7 @@ static PyObject *asyncio_task_print_stack_func;
static PyObject *asyncio_task_repr_info_func;
static PyObject *asyncio_InvalidStateError;
static PyObject *asyncio_CancelledError;
+static PyObject *context_kwname;
/* WeakSet containing all alive tasks. */
@@ -59,6 +60,7 @@ typedef enum {
PyObject_HEAD \
PyObject *prefix##_loop; \
PyObject *prefix##_callback0; \
+ PyContext *prefix##_context0; \
PyObject *prefix##_callbacks; \
PyObject *prefix##_exception; \
PyObject *prefix##_result; \
@@ -77,6 +79,7 @@ typedef struct {
FutureObj_HEAD(task)
PyObject *task_fut_waiter;
PyObject *task_coro;
+ PyContext *task_context;
int task_must_cancel;
int task_log_destroy_pending;
} TaskObj;
@@ -336,11 +339,38 @@ get_event_loop(void)
static int
-call_soon(PyObject *loop, PyObject *func, PyObject *arg)
+call_soon(PyObject *loop, PyObject *func, PyObject *arg, PyContext *ctx)
{
PyObject *handle;
- handle = _PyObject_CallMethodIdObjArgs(
- loop, &PyId_call_soon, func, arg, NULL);
+ PyObject *stack[3];
+ Py_ssize_t nargs;
+
+ if (ctx == NULL) {
+ handle = _PyObject_CallMethodIdObjArgs(
+ loop, &PyId_call_soon, func, arg, NULL);
+ }
+ else {
+ /* Use FASTCALL to pass a keyword-only argument to call_soon */
+
+ PyObject *callable = _PyObject_GetAttrId(loop, &PyId_call_soon);
+ if (callable == NULL) {
+ return -1;
+ }
+
+ /* All refs in 'stack' are borrowed. */
+ nargs = 1;
+ stack[0] = func;
+ if (arg != NULL) {
+ stack[1] = arg;
+ nargs++;
+ }
+ stack[nargs] = (PyObject *)ctx;
+
+ handle = _PyObject_FastCallKeywords(
+ callable, stack, nargs, context_kwname);
+ Py_DECREF(callable);
+ }
+
if (handle == NULL) {
return -1;
}
@@ -387,8 +417,11 @@ future_schedule_callbacks(FutureObj *fut)
/* There's a 1st callback */
int ret = call_soon(
- fut->fut_loop, fut->fut_callback0, (PyObject *)fut);
+ fut->fut_loop, fut->fut_callback0,
+ (PyObject *)fut, fut->fut_context0);
+
Py_CLEAR(fut->fut_callback0);
+ Py_CLEAR(fut->fut_context0);
if (ret) {
/* If an error occurs in pure-Python implementation,
all callbacks are cleared. */
@@ -413,9 +446,11 @@ future_schedule_callbacks(FutureObj *fut)
}
for (i = 0; i < len; i++) {
- PyObject *cb = PyList_GET_ITEM(fut->fut_callbacks, i);
+ PyObject *cb_tup = PyList_GET_ITEM(fut->fut_callbacks, i);
+ PyObject *cb = PyTuple_GET_ITEM(cb_tup, 0);
+ PyObject *ctx = PyTuple_GET_ITEM(cb_tup, 1);
- if (call_soon(fut->fut_loop, cb, (PyObject *)fut)) {
+ if (call_soon(fut->fut_loop, cb, (PyObject *)fut, (PyContext *)ctx)) {
/* If an error occurs in pure-Python implementation,
all callbacks are cleared. */
Py_CLEAR(fut->fut_callbacks);
@@ -462,6 +497,7 @@ future_init(FutureObj *fut, PyObject *loop)
}
fut->fut_callback0 = NULL;
+ fut->fut_context0 = NULL;
fut->fut_callbacks = NULL;
return 0;
@@ -566,7 +602,7 @@ future_get_result(FutureObj *fut, PyObject **result)
}
static PyObject *
-future_add_done_callback(FutureObj *fut, PyObject *arg)
+future_add_done_callback(FutureObj *fut, PyObject *arg, PyContext *ctx)
{
if (!future_is_alive(fut)) {
PyErr_SetString(PyExc_RuntimeError, "uninitialized Future object");
@@ -576,7 +612,7 @@ future_add_done_callback(FutureObj *fut, PyObject *arg)
if (fut->fut_state != STATE_PENDING) {
/* The future is done/cancelled, so schedule the callback
right away. */
- if (call_soon(fut->fut_loop, arg, (PyObject*) fut)) {
+ if (call_soon(fut->fut_loop, arg, (PyObject*) fut, ctx)) {
return NULL;
}
}
@@ -602,24 +638,38 @@ future_add_done_callback(FutureObj *fut, PyObject *arg)
with a new list and add the new callback to it.
*/
- if (fut->fut_callbacks != NULL) {
- int err = PyList_Append(fut->fut_callbacks, arg);
- if (err != 0) {
- return NULL;
- }
- }
- else if (fut->fut_callback0 == NULL) {
+ if (fut->fut_callbacks == NULL && fut->fut_callback0 == NULL) {
Py_INCREF(arg);
fut->fut_callback0 = arg;
+ Py_INCREF(ctx);
+ fut->fut_context0 = ctx;
}
else {
- fut->fut_callbacks = PyList_New(1);
- if (fut->fut_callbacks == NULL) {
+ PyObject *tup = PyTuple_New(2);
+ if (tup == NULL) {
return NULL;
}
-
Py_INCREF(arg);
- PyList_SET_ITEM(fut->fut_callbacks, 0, arg);
+ PyTuple_SET_ITEM(tup, 0, arg);
+ Py_INCREF(ctx);
+ PyTuple_SET_ITEM(tup, 1, (PyObject *)ctx);
+
+ if (fut->fut_callbacks != NULL) {
+ int err = PyList_Append(fut->fut_callbacks, tup);
+ if (err) {
+ Py_DECREF(tup);
+ return NULL;
+ }
+ Py_DECREF(tup);
+ }
+ else {
+ fut->fut_callbacks = PyList_New(1);
+ if (fut->fut_callbacks == NULL) {
+ return NULL;
+ }
+
+ PyList_SET_ITEM(fut->fut_callbacks, 0, tup); /* borrow */
+ }
}
}
@@ -676,6 +726,7 @@ FutureObj_clear(FutureObj *fut)
{
Py_CLEAR(fut->fut_loop);
Py_CLEAR(fut->fut_callback0);
+ Py_CLEAR(fut->fut_context0);
Py_CLEAR(fut->fut_callbacks);
Py_CLEAR(fut->fut_result);
Py_CLEAR(fut->fut_exception);
@@ -689,6 +740,7 @@ FutureObj_traverse(FutureObj *fut, visitproc visit, void *arg)
{
Py_VISIT(fut->fut_loop);
Py_VISIT(fut->fut_callback0);
+ Py_VISIT(fut->fut_context0);
Py_VISIT(fut->fut_callbacks);
Py_VISIT(fut->fut_result);
Py_VISIT(fut->fut_exception);
@@ -821,6 +873,8 @@ _asyncio.Future.add_done_callback
fn: object
/
+ *
+ context: object = NULL
Add a callback to be run when the future becomes done.
@@ -830,10 +884,21 @@ scheduled with call_soon.
[clinic start generated code]*/
static PyObject *
-_asyncio_Future_add_done_callback(FutureObj *self, PyObject *fn)
-/*[clinic end generated code: output=819e09629b2ec2b5 input=8f818b39990b027d]*/
+_asyncio_Future_add_done_callback_impl(FutureObj *self, PyObject *fn,
+ PyObject *context)
+/*[clinic end generated code: output=7ce635bbc9554c1e input=15ab0693a96e9533]*/
{
- return future_add_done_callback(self, fn);
+ if (context == NULL) {
+ context = (PyObject *)PyContext_CopyCurrent();
+ if (context == NULL) {
+ return NULL;
+ }
+ PyObject *res = future_add_done_callback(
+ self, fn, (PyContext *)context);
+ Py_DECREF(context);
+ return res;
+ }
+ return future_add_done_callback(self, fn, (PyContext *)context);
}
/*[clinic input]
@@ -865,6 +930,7 @@ _asyncio_Future_remove_done_callback(FutureObj *self, PyObject *fn)
if (cmp == 1) {
/* callback0 == fn */
Py_CLEAR(self->fut_callback0);
+ Py_CLEAR(self->fut_context0);
cleared_callback0 = 1;
}
}
@@ -880,8 +946,9 @@ _asyncio_Future_remove_done_callback(FutureObj *self, PyObject *fn)
}
if (len == 1) {
+ PyObject *cb_tup = PyList_GET_ITEM(self->fut_callbacks, 0);
int cmp = PyObject_RichCompareBool(
- fn, PyList_GET_ITEM(self->fut_callbacks, 0), Py_EQ);
+ fn, PyTuple_GET_ITEM(cb_tup, 0), Py_EQ);
if (cmp == -1) {
return NULL;
}
@@ -903,7 +970,7 @@ _asyncio_Future_remove_done_callback(FutureObj *self, PyObject *fn)
int ret;
PyObject *item = PyList_GET_ITEM(self->fut_callbacks, i);
Py_INCREF(item);
- ret = PyObject_RichCompareBool(fn, item, Py_EQ);
+ ret = PyObject_RichCompareBool(fn, PyTuple_GET_ITEM(item, 0), Py_EQ);
if (ret == 0) {
if (j < len) {
PyList_SET_ITEM(newlist, j, item);
@@ -1081,47 +1148,49 @@ static PyObject *
FutureObj_get_callbacks(FutureObj *fut)
{
Py_ssize_t i;
- Py_ssize_t len;
- PyObject *new_list;
ENSURE_FUTURE_ALIVE(fut)
- if (fut->fut_callbacks == NULL) {
- if (fut->fut_callback0 == NULL) {
+ if (fut->fut_callback0 == NULL) {
+ if (fut->fut_callbacks == NULL) {
Py_RETURN_NONE;
}
- else {
- new_list = PyList_New(1);
- if (new_list == NULL) {
- return NULL;
- }
- Py_INCREF(fut->fut_callback0);
- PyList_SET_ITEM(new_list, 0, fut->fut_callback0);
- return new_list;
- }
- }
- assert(fut->fut_callbacks != NULL);
-
- if (fut->fut_callback0 == NULL) {
Py_INCREF(fut->fut_callbacks);
return fut->fut_callbacks;
}
- assert(fut->fut_callback0 != NULL);
+ Py_ssize_t len = 1;
+ if (fut->fut_callbacks != NULL) {
+ len += PyList_GET_SIZE(fut->fut_callbacks);
+ }
- len = PyList_GET_SIZE(fut->fut_callbacks);
- new_list = PyList_New(len + 1);
+
+ PyObject *new_list = PyList_New(len);
if (new_list == NULL) {
return NULL;
}
+ PyObject *tup0 = PyTuple_New(2);
+ if (tup0 == NULL) {
+ Py_DECREF(new_list);
+ return NULL;
+ }
+
Py_INCREF(fut->fut_callback0);
- PyList_SET_ITEM(new_list, 0, fut->fut_callback0);
- for (i = 0; i < len; i++) {
- PyObject *cb = PyList_GET_ITEM(fut->fut_callbacks, i);
- Py_INCREF(cb);
- PyList_SET_ITEM(new_list, i + 1, cb);
+ PyTuple_SET_ITEM(tup0, 0, fut->fut_callback0);
+ assert(fut->fut_context0 != NULL);
+ Py_INCREF(fut->fut_context0);
+ PyTuple_SET_ITEM(tup0, 1, (PyObject *)fut->fut_context0);
+
+ PyList_SET_ITEM(new_list, 0, tup0);
+
+ if (fut->fut_callbacks != NULL) {
+ for (i = 0; i < PyList_GET_SIZE(fut->fut_callbacks); i++) {
+ PyObject *cb = PyList_GET_ITEM(fut->fut_callbacks, i);
+ Py_INCREF(cb);
+ PyList_SET_ITEM(new_list, i + 1, cb);
+ }
}
return new_list;
@@ -1912,6 +1981,11 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop)
return -1;
}
+ self->task_context = PyContext_CopyCurrent();
+ if (self->task_context == NULL) {
+ return -1;
+ }
+
self->task_fut_waiter = NULL;
self->task_must_cancel = 0;
self->task_log_destroy_pending = 1;
@@ -1928,6 +2002,7 @@ static int
TaskObj_clear(TaskObj *task)
{
(void)FutureObj_clear((FutureObj*) task);
+ Py_CLEAR(task->task_context);
Py_CLEAR(task->task_coro);
Py_CLEAR(task->task_fut_waiter);
return 0;
@@ -1936,6 +2011,7 @@ TaskObj_clear(TaskObj *task)
static int
TaskObj_traverse(TaskObj *task, visitproc visit, void *arg)
{
+ Py_VISIT(task->task_context);
Py_VISIT(task->task_coro);
Py_VISIT(task->task_fut_waiter);
(void)FutureObj_traverse((FutureObj*) task, visit, arg);
@@ -2451,7 +2527,7 @@ task_call_step_soon(TaskObj *task, PyObject *arg)
return -1;
}
- int ret = call_soon(task->task_loop, cb, NULL);
+ int ret = call_soon(task->task_loop, cb, NULL, task->task_context);
Py_DECREF(cb);
return ret;
}
@@ -2650,7 +2726,8 @@ set_exception:
if (wrapper == NULL) {
goto fail;
}
- res = future_add_done_callback((FutureObj*)result, wrapper);
+ res = future_add_done_callback(
+ (FutureObj*)result, wrapper, task->task_context);
Py_DECREF(wrapper);
if (res == NULL) {
goto fail;
@@ -2724,14 +2801,23 @@ set_exception:
goto fail;
}
- /* result.add_done_callback(task._wakeup) */
wrapper = TaskWakeupMethWrapper_new(task);
if (wrapper == NULL) {
goto fail;
}
- res = _PyObject_CallMethodIdObjArgs(result,
- &PyId_add_done_callback,
- wrapper, NULL);
+
+ /* result.add_done_callback(task._wakeup) */
+ PyObject *add_cb = _PyObject_GetAttrId(
+ result, &PyId_add_done_callback);
+ if (add_cb == NULL) {
+ goto fail;
+ }
+ PyObject *stack[2];
+ stack[0] = wrapper;
+ stack[1] = (PyObject *)task->task_context;
+ res = _PyObject_FastCallKeywords(
+ add_cb, stack, 1, context_kwname);
+ Py_DECREF(add_cb);
Py_DECREF(wrapper);
if (res == NULL) {
goto fail;
@@ -3141,6 +3227,8 @@ module_free(void *m)
Py_CLEAR(current_tasks);
Py_CLEAR(iscoroutine_typecache);
+ Py_CLEAR(context_kwname);
+
module_free_freelists();
}
@@ -3164,6 +3252,17 @@ module_init(void)
goto fail;
}
+
+ context_kwname = PyTuple_New(1);
+ if (context_kwname == NULL) {
+ goto fail;
+ }
+ PyObject *context_str = PyUnicode_FromString("context");
+ if (context_str == NULL) {
+ goto fail;
+ }
+ PyTuple_SET_ITEM(context_kwname, 0, context_str);
+
#define WITH_MOD(NAME) \
Py_CLEAR(module); \
module = PyImport_ImportModule(NAME); \
diff --git a/Modules/_contextvarsmodule.c b/Modules/_contextvarsmodule.c
new file mode 100644
index 0000000..b7d112d
--- /dev/null
+++ b/Modules/_contextvarsmodule.c
@@ -0,0 +1,75 @@
+#include "Python.h"
+
+#include "clinic/_contextvarsmodule.c.h"
+
+/*[clinic input]
+module _contextvars
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
+
+
+/*[clinic input]
+_contextvars.copy_context
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_copy_context_impl(PyObject *module)
+/*[clinic end generated code: output=1fcd5da7225c4fa9 input=89bb9ae485888440]*/
+{
+ return (PyObject *)PyContext_CopyCurrent();
+}
+
+
+PyDoc_STRVAR(module_doc, "Context Variables");
+
+static PyMethodDef _contextvars_methods[] = {
+ _CONTEXTVARS_COPY_CONTEXT_METHODDEF
+ {NULL, NULL}
+};
+
+static struct PyModuleDef _contextvarsmodule = {
+ PyModuleDef_HEAD_INIT, /* m_base */
+ "_contextvars", /* m_name */
+ module_doc, /* m_doc */
+ -1, /* m_size */
+ _contextvars_methods, /* m_methods */
+ NULL, /* m_slots */
+ NULL, /* m_traverse */
+ NULL, /* m_clear */
+ NULL, /* m_free */
+};
+
+PyMODINIT_FUNC
+PyInit__contextvars(void)
+{
+ PyObject *m = PyModule_Create(&_contextvarsmodule);
+ if (m == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(&PyContext_Type);
+ if (PyModule_AddObject(m, "Context",
+ (PyObject *)&PyContext_Type) < 0)
+ {
+ Py_DECREF(&PyContext_Type);
+ return NULL;
+ }
+
+ Py_INCREF(&PyContextVar_Type);
+ if (PyModule_AddObject(m, "ContextVar",
+ (PyObject *)&PyContextVar_Type) < 0)
+ {
+ Py_DECREF(&PyContextVar_Type);
+ return NULL;
+ }
+
+ Py_INCREF(&PyContextToken_Type);
+ if (PyModule_AddObject(m, "Token",
+ (PyObject *)&PyContextToken_Type) < 0)
+ {
+ Py_DECREF(&PyContextToken_Type);
+ return NULL;
+ }
+
+ return m;
+}
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index 0d6bf45..e3be7d3 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -4438,6 +4438,13 @@ test_pythread_tss_key_state(PyObject *self, PyObject *args)
}
+static PyObject*
+new_hamt(PyObject *self, PyObject *args)
+{
+ return _PyContext_NewHamtForTests();
+}
+
+
static PyMethodDef TestMethods[] = {
{"raise_exception", raise_exception, METH_VARARGS},
{"raise_memoryerror", (PyCFunction)raise_memoryerror, METH_NOARGS},
@@ -4655,6 +4662,7 @@ static PyMethodDef TestMethods[] = {
{"get_mapping_values", get_mapping_values, METH_O},
{"get_mapping_items", get_mapping_items, METH_O},
{"test_pythread_tss_key_state", test_pythread_tss_key_state, METH_VARARGS},
+ {"hamt", new_hamt, METH_NOARGS},
{NULL, NULL} /* sentinel */
};
diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h
index f2e0f40..9fc9d6b 100644
--- a/Modules/clinic/_asynciomodule.c.h
+++ b/Modules/clinic/_asynciomodule.c.h
@@ -110,7 +110,7 @@ PyDoc_STRVAR(_asyncio_Future_set_exception__doc__,
{"set_exception", (PyCFunction)_asyncio_Future_set_exception, METH_O, _asyncio_Future_set_exception__doc__},
PyDoc_STRVAR(_asyncio_Future_add_done_callback__doc__,
-"add_done_callback($self, fn, /)\n"
+"add_done_callback($self, fn, /, *, context=None)\n"
"--\n"
"\n"
"Add a callback to be run when the future becomes done.\n"
@@ -120,7 +120,30 @@ PyDoc_STRVAR(_asyncio_Future_add_done_callback__doc__,
"scheduled with call_soon.");
#define _ASYNCIO_FUTURE_ADD_DONE_CALLBACK_METHODDEF \
- {"add_done_callback", (PyCFunction)_asyncio_Future_add_done_callback, METH_O, _asyncio_Future_add_done_callback__doc__},
+ {"add_done_callback", (PyCFunction)_asyncio_Future_add_done_callback, METH_FASTCALL|METH_KEYWORDS, _asyncio_Future_add_done_callback__doc__},
+
+static PyObject *
+_asyncio_Future_add_done_callback_impl(FutureObj *self, PyObject *fn,
+ PyObject *context);
+
+static PyObject *
+_asyncio_Future_add_done_callback(FutureObj *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+{
+ PyObject *return_value = NULL;
+ static const char * const _keywords[] = {"", "context", NULL};
+ static _PyArg_Parser _parser = {"O|$O:add_done_callback", _keywords, 0};
+ PyObject *fn;
+ PyObject *context = NULL;
+
+ if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
+ &fn, &context)) {
+ goto exit;
+ }
+ return_value = _asyncio_Future_add_done_callback_impl(self, fn, context);
+
+exit:
+ return return_value;
+}
PyDoc_STRVAR(_asyncio_Future_remove_done_callback__doc__,
"remove_done_callback($self, fn, /)\n"
@@ -763,4 +786,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs,
exit:
return return_value;
}
-/*[clinic end generated code: output=616e814431893dcc input=a9049054013a1b77]*/
+/*[clinic end generated code: output=bcbaf1b2480f4aa9 input=a9049054013a1b77]*/
diff --git a/Modules/clinic/_contextvarsmodule.c.h b/Modules/clinic/_contextvarsmodule.c.h
new file mode 100644
index 0000000..b1885e4
--- /dev/null
+++ b/Modules/clinic/_contextvarsmodule.c.h
@@ -0,0 +1,21 @@
+/*[clinic input]
+preserve
+[clinic start generated code]*/
+
+PyDoc_STRVAR(_contextvars_copy_context__doc__,
+"copy_context($module, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_COPY_CONTEXT_METHODDEF \
+ {"copy_context", (PyCFunction)_contextvars_copy_context, METH_NOARGS, _contextvars_copy_context__doc__},
+
+static PyObject *
+_contextvars_copy_context_impl(PyObject *module);
+
+static PyObject *
+_contextvars_copy_context(PyObject *module, PyObject *Py_UNUSED(ignored))
+{
+ return _contextvars_copy_context_impl(module);
+}
+/*[clinic end generated code: output=26e07024451baf52 input=a9049054013a1b77]*/
diff --git a/Modules/gcmodule.c b/Modules/gcmodule.c
index ea3c294..8ba1093 100644
--- a/Modules/gcmodule.c
+++ b/Modules/gcmodule.c
@@ -24,6 +24,7 @@
*/
#include "Python.h"
+#include "internal/context.h"
#include "internal/mem.h"
#include "internal/pystate.h"
#include "frameobject.h" /* for PyFrame_ClearFreeList */
@@ -790,6 +791,7 @@ clear_freelists(void)
(void)PyDict_ClearFreeList();
(void)PySet_ClearFreeList();
(void)PyAsyncGen_ClearFreeLists();
+ (void)PyContext_ClearFreeList();
}
/* This is the main function. Read this to understand how the
diff --git a/Objects/object.c b/Objects/object.c
index 62d7fbe..8cec6e2 100644
--- a/Objects/object.c
+++ b/Objects/object.c
@@ -3,6 +3,7 @@
#include "Python.h"
#include "internal/pystate.h"
+#include "internal/context.h"
#include "frameobject.h"
#ifdef __cplusplus
diff --git a/PCbuild/_contextvars.vcxproj b/PCbuild/_contextvars.vcxproj
new file mode 100644
index 0000000..7418e86
--- /dev/null
+++ b/PCbuild/_contextvars.vcxproj
@@ -0,0 +1,77 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|Win32">
+ <Configuration>Debug</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="PGInstrument|Win32">
+ <Configuration>PGInstrument</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="PGInstrument|x64">
+ <Configuration>PGInstrument</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="PGUpdate|Win32">
+ <Configuration>PGUpdate</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="PGUpdate|x64">
+ <Configuration>PGUpdate</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|Win32">
+ <Configuration>Release</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <ProjectGuid>{B8BF1D81-09DC-42D4-B406-4F868B33A89E}</ProjectGuid>
+ <RootNamespace>_contextvars</RootNamespace>
+ <Keyword>Win32Proj</Keyword>
+ </PropertyGroup>
+ <Import Project="python.props" />
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Label="Configuration">
+ <ConfigurationType>DynamicLibrary</ConfigurationType>
+ <CharacterSet>NotSet</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <PropertyGroup>
+ <TargetExt>.pyd</TargetExt>
+ </PropertyGroup>
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="pyproject.props" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup>
+ <_ProjectFileVersion>10.0.30319.1</_ProjectFileVersion>
+ </PropertyGroup>
+ <ItemGroup>
+ <ClCompile Include="..\Modules\_contextvarsmodule.c" />
+ </ItemGroup>
+ <ItemGroup>
+ <ResourceCompile Include="..\PC\python_nt.rc" />
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="pythoncore.vcxproj">
+ <Project>{cf7ac3d1-e2df-41d2-bea6-1e2556cdea26}</Project>
+ <ReferenceOutputAssembly>false</ReferenceOutputAssembly>
+ </ProjectReference>
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project>
diff --git a/PCbuild/_contextvars.vcxproj.filters b/PCbuild/_contextvars.vcxproj.filters
new file mode 100644
index 0000000..b3002b7
--- /dev/null
+++ b/PCbuild/_contextvars.vcxproj.filters
@@ -0,0 +1,16 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ResourceCompile Include="..\PC\python_nt.rc" />
+ </ItemGroup>
+ <ItemGroup>
+ <Filter Include="Source Files">
+ <UniqueIdentifier>{7CBD8910-233D-4E9A-9164-9BA66C1F0E6D}</UniqueIdentifier>
+ </Filter>
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="..\Modules\_contextvarsmodule.c">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ </ItemGroup>
+</Project>
diff --git a/PCbuild/_decimal.vcxproj b/PCbuild/_decimal.vcxproj
index b14f310..df9f600 100644
--- a/PCbuild/_decimal.vcxproj
+++ b/PCbuild/_decimal.vcxproj
@@ -121,4 +121,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
-</Project> \ No newline at end of file
+</Project>
diff --git a/PCbuild/pcbuild.proj b/PCbuild/pcbuild.proj
index 848b3b2..5e34195 100644
--- a/PCbuild/pcbuild.proj
+++ b/PCbuild/pcbuild.proj
@@ -49,7 +49,7 @@
<!-- pyshellext.dll -->
<Projects Include="pyshellext.vcxproj" />
<!-- Extension modules -->
- <ExtensionModules Include="_asyncio;_ctypes;_decimal;_distutils_findvs;_elementtree;_msi;_multiprocessing;_overlapped;pyexpat;_queue;select;unicodedata;winsound" />
+ <ExtensionModules Include="_asyncio;_contextvars;_ctypes;_decimal;_distutils_findvs;_elementtree;_msi;_multiprocessing;_overlapped;pyexpat;_queue;select;unicodedata;winsound" />
<!-- Extension modules that require external sources -->
<ExternalModules Include="_bz2;_lzma;_sqlite3" />
<!-- _ssl will build _socket as well, which may cause conflicts in parallel builds -->
diff --git a/PCbuild/pythoncore.vcxproj b/PCbuild/pythoncore.vcxproj
index bf2ce66..fbcd051 100644
--- a/PCbuild/pythoncore.vcxproj
+++ b/PCbuild/pythoncore.vcxproj
@@ -94,6 +94,7 @@
<ClInclude Include="..\Include\codecs.h" />
<ClInclude Include="..\Include\compile.h" />
<ClInclude Include="..\Include\complexobject.h" />
+ <ClInclude Include="..\Include\context.h" />
<ClInclude Include="..\Include\datetime.h" />
<ClInclude Include="..\Include\descrobject.h" />
<ClInclude Include="..\Include\dictobject.h" />
@@ -112,7 +113,9 @@
<ClInclude Include="..\Include\import.h" />
<ClInclude Include="..\Include\internal\ceval.h" />
<ClInclude Include="..\Include\internal\condvar.h" />
+ <ClInclude Include="..\Include\internal\context.h" />
<ClInclude Include="..\Include\internal\gil.h" />
+ <ClInclude Include="..\Include\internal\hamt.h" />
<ClInclude Include="..\Include\internal\mem.h" />
<ClInclude Include="..\Include\internal\pystate.h" />
<ClInclude Include="..\Include\internal\warnings.h" />
@@ -232,6 +235,7 @@
<ClCompile Include="..\Modules\_blake2\blake2s_impl.c" />
<ClCompile Include="..\Modules\_codecsmodule.c" />
<ClCompile Include="..\Modules\_collectionsmodule.c" />
+ <ClCompile Include="..\Modules\_contextvarsmodule.c" />
<ClCompile Include="..\Modules\_csv.c" />
<ClCompile Include="..\Modules\_functoolsmodule.c" />
<ClCompile Include="..\Modules\_heapqmodule.c" />
@@ -359,6 +363,7 @@
<ClCompile Include="..\Python\ceval.c" />
<ClCompile Include="..\Python\codecs.c" />
<ClCompile Include="..\Python\compile.c" />
+ <ClCompile Include="..\Python\context.c" />
<ClCompile Include="..\Python\dynamic_annotations.c" />
<ClCompile Include="..\Python\dynload_win.c" />
<ClCompile Include="..\Python\errors.c" />
@@ -373,6 +378,7 @@
<ClCompile Include="..\Python\getplatform.c" />
<ClCompile Include="..\Python\getversion.c" />
<ClCompile Include="..\Python\graminit.c" />
+ <ClCompile Include="..\Python\hamt.c" />
<ClCompile Include="..\Python\import.c" />
<ClCompile Include="..\Python\importdl.c" />
<ClCompile Include="..\Python\marshal.c" />
diff --git a/PCbuild/pythoncore.vcxproj.filters b/PCbuild/pythoncore.vcxproj.filters
index 13600cb..a10686c 100644
--- a/PCbuild/pythoncore.vcxproj.filters
+++ b/PCbuild/pythoncore.vcxproj.filters
@@ -81,6 +81,9 @@
<ClInclude Include="..\Include\complexobject.h">
<Filter>Include</Filter>
</ClInclude>
+ <ClInclude Include="..\Include\context.h">
+ <Filter>Include</Filter>
+ </ClInclude>
<ClInclude Include="..\Include\datetime.h">
<Filter>Include</Filter>
</ClInclude>
@@ -135,9 +138,15 @@
<ClInclude Include="..\Include\internal\condvar.h">
<Filter>Include</Filter>
</ClInclude>
+ <ClInclude Include="..\Include\internal\context.h">
+ <Filter>Include</Filter>
+ </ClInclude>
<ClInclude Include="..\Include\internal\gil.h">
<Filter>Include</Filter>
</ClInclude>
+ <ClInclude Include="..\Include\internal\hamt.h">
+ <Filter>Include</Filter>
+ </ClInclude>
<ClInclude Include="..\Include\internal\mem.h">
<Filter>Include</Filter>
</ClInclude>
@@ -842,6 +851,9 @@
<ClCompile Include="..\Python\compile.c">
<Filter>Python</Filter>
</ClCompile>
+ <ClCompile Include="..\Python\context.h">
+ <Filter>Python</Filter>
+ </ClCompile>
<ClCompile Include="..\Python\dynamic_annotations.c">
<Filter>Python</Filter>
</ClCompile>
@@ -884,6 +896,9 @@
<ClCompile Include="..\Python\graminit.c">
<Filter>Python</Filter>
</ClCompile>
+ <ClCompile Include="..\Python\hamt.h">
+ <Filter>Python</Filter>
+ </ClCompile>
<ClCompile Include="..\Python\import.c">
<Filter>Python</Filter>
</ClCompile>
@@ -998,6 +1013,9 @@
<ClCompile Include="..\Modules\_asynciomodule.c">
<Filter>Modules</Filter>
</ClCompile>
+ <ClCompile Include="..\Modules\_contextvarsmodule.c">
+ <Filter>Modules</Filter>
+ </ClCompile>
<ClCompile Include="$(zlibDir)\adler32.c">
<Filter>Modules\zlib</Filter>
</ClCompile>
diff --git a/Python/clinic/context.c.h b/Python/clinic/context.c.h
new file mode 100644
index 0000000..dcf4c21
--- /dev/null
+++ b/Python/clinic/context.c.h
@@ -0,0 +1,146 @@
+/*[clinic input]
+preserve
+[clinic start generated code]*/
+
+PyDoc_STRVAR(_contextvars_Context_get__doc__,
+"get($self, key, default=None, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXT_GET_METHODDEF \
+ {"get", (PyCFunction)_contextvars_Context_get, METH_FASTCALL, _contextvars_Context_get__doc__},
+
+static PyObject *
+_contextvars_Context_get_impl(PyContext *self, PyObject *key,
+ PyObject *default_value);
+
+static PyObject *
+_contextvars_Context_get(PyContext *self, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *key;
+ PyObject *default_value = Py_None;
+
+ if (!_PyArg_UnpackStack(args, nargs, "get",
+ 1, 2,
+ &key, &default_value)) {
+ goto exit;
+ }
+ return_value = _contextvars_Context_get_impl(self, key, default_value);
+
+exit:
+ return return_value;
+}
+
+PyDoc_STRVAR(_contextvars_Context_items__doc__,
+"items($self, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF \
+ {"items", (PyCFunction)_contextvars_Context_items, METH_NOARGS, _contextvars_Context_items__doc__},
+
+static PyObject *
+_contextvars_Context_items_impl(PyContext *self);
+
+static PyObject *
+_contextvars_Context_items(PyContext *self, PyObject *Py_UNUSED(ignored))
+{
+ return _contextvars_Context_items_impl(self);
+}
+
+PyDoc_STRVAR(_contextvars_Context_keys__doc__,
+"keys($self, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXT_KEYS_METHODDEF \
+ {"keys", (PyCFunction)_contextvars_Context_keys, METH_NOARGS, _contextvars_Context_keys__doc__},
+
+static PyObject *
+_contextvars_Context_keys_impl(PyContext *self);
+
+static PyObject *
+_contextvars_Context_keys(PyContext *self, PyObject *Py_UNUSED(ignored))
+{
+ return _contextvars_Context_keys_impl(self);
+}
+
+PyDoc_STRVAR(_contextvars_Context_values__doc__,
+"values($self, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXT_VALUES_METHODDEF \
+ {"values", (PyCFunction)_contextvars_Context_values, METH_NOARGS, _contextvars_Context_values__doc__},
+
+static PyObject *
+_contextvars_Context_values_impl(PyContext *self);
+
+static PyObject *
+_contextvars_Context_values(PyContext *self, PyObject *Py_UNUSED(ignored))
+{
+ return _contextvars_Context_values_impl(self);
+}
+
+PyDoc_STRVAR(_contextvars_Context_copy__doc__,
+"copy($self, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXT_COPY_METHODDEF \
+ {"copy", (PyCFunction)_contextvars_Context_copy, METH_NOARGS, _contextvars_Context_copy__doc__},
+
+static PyObject *
+_contextvars_Context_copy_impl(PyContext *self);
+
+static PyObject *
+_contextvars_Context_copy(PyContext *self, PyObject *Py_UNUSED(ignored))
+{
+ return _contextvars_Context_copy_impl(self);
+}
+
+PyDoc_STRVAR(_contextvars_ContextVar_get__doc__,
+"get($self, default=None, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF \
+ {"get", (PyCFunction)_contextvars_ContextVar_get, METH_FASTCALL, _contextvars_ContextVar_get__doc__},
+
+static PyObject *
+_contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value);
+
+static PyObject *
+_contextvars_ContextVar_get(PyContextVar *self, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *default_value = NULL;
+
+ if (!_PyArg_UnpackStack(args, nargs, "get",
+ 0, 1,
+ &default_value)) {
+ goto exit;
+ }
+ return_value = _contextvars_ContextVar_get_impl(self, default_value);
+
+exit:
+ return return_value;
+}
+
+PyDoc_STRVAR(_contextvars_ContextVar_set__doc__,
+"set($self, value, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF \
+ {"set", (PyCFunction)_contextvars_ContextVar_set, METH_O, _contextvars_ContextVar_set__doc__},
+
+PyDoc_STRVAR(_contextvars_ContextVar_reset__doc__,
+"reset($self, token, /)\n"
+"--\n"
+"\n");
+
+#define _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF \
+ {"reset", (PyCFunction)_contextvars_ContextVar_reset, METH_O, _contextvars_ContextVar_reset__doc__},
+/*[clinic end generated code: output=d9a675e3a52a14fc input=a9049054013a1b77]*/
diff --git a/Python/context.c b/Python/context.c
new file mode 100644
index 0000000..2f1d0f5
--- /dev/null
+++ b/Python/context.c
@@ -0,0 +1,1220 @@
+#include "Python.h"
+
+#include "structmember.h"
+#include "internal/pystate.h"
+#include "internal/context.h"
+#include "internal/hamt.h"
+
+
+#define CONTEXT_FREELIST_MAXLEN 255
+static PyContext *ctx_freelist = NULL;
+static Py_ssize_t ctx_freelist_len = 0;
+
+
+#include "clinic/context.c.h"
+/*[clinic input]
+module _contextvars
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
+
+
+/////////////////////////// Context API
+
+
+static PyContext *
+context_new_empty(void);
+
+static PyContext *
+context_new_from_vars(PyHamtObject *vars);
+
+static inline PyContext *
+context_get(void);
+
+static PyContextToken *
+token_new(PyContext *ctx, PyContextVar *var, PyObject *val);
+
+static PyContextVar *
+contextvar_new(PyObject *name, PyObject *def);
+
+static int
+contextvar_set(PyContextVar *var, PyObject *val);
+
+static int
+contextvar_del(PyContextVar *var);
+
+
+PyObject *
+_PyContext_NewHamtForTests(void)
+{
+ return (PyObject *)_PyHamt_New();
+}
+
+
+PyContext *
+PyContext_New(void)
+{
+ return context_new_empty();
+}
+
+
+PyContext *
+PyContext_Copy(PyContext * ctx)
+{
+ return context_new_from_vars(ctx->ctx_vars);
+}
+
+
+PyContext *
+PyContext_CopyCurrent(void)
+{
+ PyContext *ctx = context_get();
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ return context_new_from_vars(ctx->ctx_vars);
+}
+
+
+int
+PyContext_Enter(PyContext *ctx)
+{
+ if (ctx->ctx_entered) {
+ PyErr_Format(PyExc_RuntimeError,
+ "cannot enter context: %R is already entered", ctx);
+ return -1;
+ }
+
+ PyThreadState *ts = PyThreadState_Get();
+
+ ctx->ctx_prev = (PyContext *)ts->context; /* borrow */
+ ctx->ctx_entered = 1;
+
+ Py_INCREF(ctx);
+ ts->context = (PyObject *)ctx;
+ ts->context_ver++;
+
+ return 0;
+}
+
+
+int
+PyContext_Exit(PyContext *ctx)
+{
+ if (!ctx->ctx_entered) {
+ PyErr_Format(PyExc_RuntimeError,
+ "cannot exit context: %R has not been entered", ctx);
+ return -1;
+ }
+
+ PyThreadState *ts = PyThreadState_Get();
+
+ if (ts->context != (PyObject *)ctx) {
+ /* Can only happen if someone misuses the C API */
+ PyErr_SetString(PyExc_RuntimeError,
+ "cannot exit context: thread state references "
+ "a different context object");
+ return -1;
+ }
+
+ Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
+ ts->context_ver++;
+
+ ctx->ctx_prev = NULL;
+ ctx->ctx_entered = 0;
+
+ return 0;
+}
+
+
+PyContextVar *
+PyContextVar_New(const char *name, PyObject *def)
+{
+ PyObject *pyname = PyUnicode_FromString(name);
+ if (pyname == NULL) {
+ return NULL;
+ }
+ return contextvar_new(pyname, def);
+}
+
+
+int
+PyContextVar_Get(PyContextVar *var, PyObject *def, PyObject **val)
+{
+ assert(PyContextVar_CheckExact(var));
+
+ PyThreadState *ts = PyThreadState_Get();
+ if (ts->context == NULL) {
+ goto not_found;
+ }
+
+ if (var->var_cached != NULL &&
+ var->var_cached_tsid == ts->id &&
+ var->var_cached_tsver == ts->context_ver)
+ {
+ *val = var->var_cached;
+ goto found;
+ }
+
+ assert(PyContext_CheckExact(ts->context));
+ PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
+
+ PyObject *found = NULL;
+ int res = _PyHamt_Find(vars, (PyObject*)var, &found);
+ if (res < 0) {
+ goto error;
+ }
+ if (res == 1) {
+ assert(found != NULL);
+ var->var_cached = found; /* borrow */
+ var->var_cached_tsid = ts->id;
+ var->var_cached_tsver = ts->context_ver;
+
+ *val = found;
+ goto found;
+ }
+
+not_found:
+ if (def == NULL) {
+ if (var->var_default != NULL) {
+ *val = var->var_default;
+ goto found;
+ }
+
+ *val = NULL;
+ goto found;
+ }
+ else {
+ *val = def;
+ goto found;
+ }
+
+found:
+ Py_XINCREF(*val);
+ return 0;
+
+error:
+ *val = NULL;
+ return -1;
+}
+
+
+PyContextToken *
+PyContextVar_Set(PyContextVar *var, PyObject *val)
+{
+ if (!PyContextVar_CheckExact(var)) {
+ PyErr_SetString(
+ PyExc_TypeError, "an instance of ContextVar was expected");
+ return NULL;
+ }
+
+ PyContext *ctx = context_get();
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ PyObject *old_val = NULL;
+ int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val);
+ if (found < 0) {
+ return NULL;
+ }
+
+ Py_XINCREF(old_val);
+ PyContextToken *tok = token_new(ctx, var, old_val);
+ Py_XDECREF(old_val);
+
+ if (contextvar_set(var, val)) {
+ Py_DECREF(tok);
+ return NULL;
+ }
+
+ return tok;
+}
+
+
+int
+PyContextVar_Reset(PyContextVar *var, PyContextToken *tok)
+{
+ if (tok->tok_used) {
+ PyErr_Format(PyExc_RuntimeError,
+ "%R has already been used once", tok);
+ return -1;
+ }
+
+ if (var != tok->tok_var) {
+ PyErr_Format(PyExc_ValueError,
+ "%R was created by a different ContextVar", tok);
+ return -1;
+ }
+
+ PyContext *ctx = context_get();
+ if (ctx != tok->tok_ctx) {
+ PyErr_Format(PyExc_ValueError,
+ "%R was created in a different Context", tok);
+ return -1;
+ }
+
+ tok->tok_used = 1;
+
+ if (tok->tok_oldval == NULL) {
+ return contextvar_del(var);
+ }
+ else {
+ return contextvar_set(var, tok->tok_oldval);
+ }
+}
+
+
+/////////////////////////// PyContext
+
+/*[clinic input]
+class _contextvars.Context "PyContext *" "&PyContext_Type"
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/
+
+
+static inline PyContext *
+_context_alloc(void)
+{
+ PyContext *ctx;
+ if (ctx_freelist_len) {
+ ctx_freelist_len--;
+ ctx = ctx_freelist;
+ ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
+ ctx->ctx_weakreflist = NULL;
+ _Py_NewReference((PyObject *)ctx);
+ }
+ else {
+ ctx = PyObject_GC_New(PyContext, &PyContext_Type);
+ if (ctx == NULL) {
+ return NULL;
+ }
+ }
+
+ ctx->ctx_vars = NULL;
+ ctx->ctx_prev = NULL;
+ ctx->ctx_entered = 0;
+ ctx->ctx_weakreflist = NULL;
+
+ return ctx;
+}
+
+
+static PyContext *
+context_new_empty(void)
+{
+ PyContext *ctx = _context_alloc();
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ ctx->ctx_vars = _PyHamt_New();
+ if (ctx->ctx_vars == NULL) {
+ Py_DECREF(ctx);
+ return NULL;
+ }
+
+ _PyObject_GC_TRACK(ctx);
+ return ctx;
+}
+
+
+static PyContext *
+context_new_from_vars(PyHamtObject *vars)
+{
+ PyContext *ctx = _context_alloc();
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(vars);
+ ctx->ctx_vars = vars;
+
+ _PyObject_GC_TRACK(ctx);
+ return ctx;
+}
+
+
+static inline PyContext *
+context_get(void)
+{
+ PyThreadState *ts = PyThreadState_Get();
+ PyContext *current_ctx = (PyContext *)ts->context;
+ if (current_ctx == NULL) {
+ current_ctx = context_new_empty();
+ if (current_ctx == NULL) {
+ return NULL;
+ }
+ ts->context = (PyObject *)current_ctx;
+ }
+ return current_ctx;
+}
+
+static int
+context_check_key_type(PyObject *key)
+{
+ if (!PyContextVar_CheckExact(key)) {
+ // abort();
+ PyErr_Format(PyExc_TypeError,
+ "a ContextVar key was expected, got %R", key);
+ return -1;
+ }
+ return 0;
+}
+
+static PyObject *
+context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) {
+ PyErr_SetString(
+ PyExc_TypeError, "Context() does not accept any arguments");
+ return NULL;
+ }
+ return (PyObject *)PyContext_New();
+}
+
+static int
+context_tp_clear(PyContext *self)
+{
+ Py_CLEAR(self->ctx_prev);
+ Py_CLEAR(self->ctx_vars);
+ return 0;
+}
+
+static int
+context_tp_traverse(PyContext *self, visitproc visit, void *arg)
+{
+ Py_VISIT(self->ctx_prev);
+ Py_VISIT(self->ctx_vars);
+ return 0;
+}
+
+static void
+context_tp_dealloc(PyContext *self)
+{
+ _PyObject_GC_UNTRACK(self);
+
+ if (self->ctx_weakreflist != NULL) {
+ PyObject_ClearWeakRefs((PyObject*)self);
+ }
+ (void)context_tp_clear(self);
+
+ if (ctx_freelist_len < CONTEXT_FREELIST_MAXLEN) {
+ ctx_freelist_len++;
+ self->ctx_weakreflist = (PyObject *)ctx_freelist;
+ ctx_freelist = self;
+ }
+ else {
+ Py_TYPE(self)->tp_free(self);
+ }
+}
+
+static PyObject *
+context_tp_iter(PyContext *self)
+{
+ return _PyHamt_NewIterKeys(self->ctx_vars);
+}
+
+static PyObject *
+context_tp_richcompare(PyObject *v, PyObject *w, int op)
+{
+ if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) ||
+ (op != Py_EQ && op != Py_NE))
+ {
+ Py_RETURN_NOTIMPLEMENTED;
+ }
+
+ int res = _PyHamt_Eq(
+ ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars);
+ if (res < 0) {
+ return NULL;
+ }
+
+ if (op == Py_NE) {
+ res = !res;
+ }
+
+ if (res) {
+ Py_RETURN_TRUE;
+ }
+ else {
+ Py_RETURN_FALSE;
+ }
+}
+
+static Py_ssize_t
+context_tp_len(PyContext *self)
+{
+ return _PyHamt_Len(self->ctx_vars);
+}
+
+static PyObject *
+context_tp_subscript(PyContext *self, PyObject *key)
+{
+ if (context_check_key_type(key)) {
+ return NULL;
+ }
+ PyObject *val = NULL;
+ int found = _PyHamt_Find(self->ctx_vars, key, &val);
+ if (found < 0) {
+ return NULL;
+ }
+ if (found == 0) {
+ PyErr_SetObject(PyExc_KeyError, key);
+ return NULL;
+ }
+ Py_INCREF(val);
+ return val;
+}
+
+static int
+context_tp_contains(PyContext *self, PyObject *key)
+{
+ if (context_check_key_type(key)) {
+ return -1;
+ }
+ PyObject *val = NULL;
+ return _PyHamt_Find(self->ctx_vars, key, &val);
+}
+
+
+/*[clinic input]
+_contextvars.Context.get
+ key: object
+ default: object = None
+ /
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context_get_impl(PyContext *self, PyObject *key,
+ PyObject *default_value)
+/*[clinic end generated code: output=0c54aa7664268189 input=8d4c33c8ecd6d769]*/
+{
+ if (context_check_key_type(key)) {
+ return NULL;
+ }
+
+ PyObject *val = NULL;
+ int found = _PyHamt_Find(self->ctx_vars, key, &val);
+ if (found < 0) {
+ return NULL;
+ }
+ if (found == 0) {
+ Py_INCREF(default_value);
+ return default_value;
+ }
+ Py_INCREF(val);
+ return val;
+}
+
+
+/*[clinic input]
+_contextvars.Context.items
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context_items_impl(PyContext *self)
+/*[clinic end generated code: output=fa1655c8a08502af input=2d570d1455004979]*/
+{
+ return _PyHamt_NewIterItems(self->ctx_vars);
+}
+
+
+/*[clinic input]
+_contextvars.Context.keys
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context_keys_impl(PyContext *self)
+/*[clinic end generated code: output=177227c6b63ec0e2 input=13005e142fbbf37d]*/
+{
+ return _PyHamt_NewIterKeys(self->ctx_vars);
+}
+
+
+/*[clinic input]
+_contextvars.Context.values
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context_values_impl(PyContext *self)
+/*[clinic end generated code: output=d286dabfc8db6dde input=c2cbc40a4470e905]*/
+{
+ return _PyHamt_NewIterValues(self->ctx_vars);
+}
+
+
+/*[clinic input]
+_contextvars.Context.copy
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context_copy_impl(PyContext *self)
+/*[clinic end generated code: output=30ba8896c4707a15 input=3e3fd72d598653ab]*/
+{
+ return (PyObject *)context_new_from_vars(self->ctx_vars);
+}
+
+
+static PyObject *
+context_run(PyContext *self, PyObject *const *args,
+ Py_ssize_t nargs, PyObject *kwnames)
+{
+ if (nargs < 1) {
+ PyErr_SetString(PyExc_TypeError,
+ "run() missing 1 required positional argument");
+ return NULL;
+ }
+
+ if (PyContext_Enter(self)) {
+ return NULL;
+ }
+
+ PyObject *call_result = _PyObject_FastCallKeywords(
+ args[0], args + 1, nargs - 1, kwnames);
+
+ if (PyContext_Exit(self)) {
+ return NULL;
+ }
+
+ return call_result;
+}
+
+
+static PyMethodDef PyContext_methods[] = {
+ _CONTEXTVARS_CONTEXT_GET_METHODDEF
+ _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
+ _CONTEXTVARS_CONTEXT_KEYS_METHODDEF
+ _CONTEXTVARS_CONTEXT_VALUES_METHODDEF
+ _CONTEXTVARS_CONTEXT_COPY_METHODDEF
+ {"run", (PyCFunction)context_run, METH_FASTCALL | METH_KEYWORDS, NULL},
+ {NULL, NULL}
+};
+
+static PySequenceMethods PyContext_as_sequence = {
+ 0, /* sq_length */
+ 0, /* sq_concat */
+ 0, /* sq_repeat */
+ 0, /* sq_item */
+ 0, /* sq_slice */
+ 0, /* sq_ass_item */
+ 0, /* sq_ass_slice */
+ (objobjproc)context_tp_contains, /* sq_contains */
+ 0, /* sq_inplace_concat */
+ 0, /* sq_inplace_repeat */
+};
+
+static PyMappingMethods PyContext_as_mapping = {
+ (lenfunc)context_tp_len, /* mp_length */
+ (binaryfunc)context_tp_subscript, /* mp_subscript */
+};
+
+PyTypeObject PyContext_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "Context",
+ sizeof(PyContext),
+ .tp_methods = PyContext_methods,
+ .tp_as_mapping = &PyContext_as_mapping,
+ .tp_as_sequence = &PyContext_as_sequence,
+ .tp_iter = (getiterfunc)context_tp_iter,
+ .tp_dealloc = (destructor)context_tp_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_richcompare = context_tp_richcompare,
+ .tp_traverse = (traverseproc)context_tp_traverse,
+ .tp_clear = (inquiry)context_tp_clear,
+ .tp_new = context_tp_new,
+ .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist),
+ .tp_hash = PyObject_HashNotImplemented,
+};
+
+
+/////////////////////////// ContextVar
+
+
+static int
+contextvar_set(PyContextVar *var, PyObject *val)
+{
+ var->var_cached = NULL;
+ PyThreadState *ts = PyThreadState_Get();
+
+ PyContext *ctx = context_get();
+ if (ctx == NULL) {
+ return -1;
+ }
+
+ PyHamtObject *new_vars = _PyHamt_Assoc(
+ ctx->ctx_vars, (PyObject *)var, val);
+ if (new_vars == NULL) {
+ return -1;
+ }
+
+ Py_SETREF(ctx->ctx_vars, new_vars);
+
+ var->var_cached = val; /* borrow */
+ var->var_cached_tsid = ts->id;
+ var->var_cached_tsver = ts->context_ver;
+ return 0;
+}
+
+static int
+contextvar_del(PyContextVar *var)
+{
+ var->var_cached = NULL;
+
+ PyContext *ctx = context_get();
+ if (ctx == NULL) {
+ return -1;
+ }
+
+ PyHamtObject *vars = ctx->ctx_vars;
+ PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var);
+ if (new_vars == NULL) {
+ return -1;
+ }
+
+ if (vars == new_vars) {
+ Py_DECREF(new_vars);
+ PyErr_SetObject(PyExc_LookupError, (PyObject *)var);
+ return -1;
+ }
+
+ Py_SETREF(ctx->ctx_vars, new_vars);
+ return 0;
+}
+
+static Py_hash_t
+contextvar_generate_hash(void *addr, PyObject *name)
+{
+ /* Take hash of `name` and XOR it with the object's addr.
+
+ The structure of the tree is encoded in objects' hashes, which
+ means that sufficiently similar hashes would result in tall trees
+ with many Collision nodes. Which would, in turn, result in slower
+ get and set operations.
+
+ The XORing helps to ensure that:
+
+ (1) sequentially allocated ContextVar objects have
+ different hashes;
+
+ (2) context variables with equal names have
+ different hashes.
+ */
+
+ Py_hash_t name_hash = PyObject_Hash(name);
+ if (name_hash == -1) {
+ return -1;
+ }
+
+ Py_hash_t res = _Py_HashPointer(addr) ^ name_hash;
+ return res == -1 ? -2 : res;
+}
+
+static PyContextVar *
+contextvar_new(PyObject *name, PyObject *def)
+{
+ if (!PyUnicode_Check(name)) {
+ PyErr_SetString(PyExc_TypeError,
+ "context variable name must be a str");
+ return NULL;
+ }
+
+ PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type);
+ if (var == NULL) {
+ return NULL;
+ }
+
+ var->var_hash = contextvar_generate_hash(var, name);
+ if (var->var_hash == -1) {
+ Py_DECREF(var);
+ return NULL;
+ }
+
+ Py_INCREF(name);
+ var->var_name = name;
+
+ Py_XINCREF(def);
+ var->var_default = def;
+
+ var->var_cached = NULL;
+ var->var_cached_tsid = 0;
+ var->var_cached_tsver = 0;
+
+ if (_PyObject_GC_IS_TRACKED(name) ||
+ (def != NULL && _PyObject_GC_IS_TRACKED(def)))
+ {
+ PyObject_GC_Track(var);
+ }
+ return var;
+}
+
+
+/*[clinic input]
+class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type"
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/
+
+
+static PyObject *
+contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ static char *kwlist[] = {"", "default", NULL};
+ PyObject *name;
+ PyObject *def = NULL;
+
+ if (!PyArg_ParseTupleAndKeywords(
+ args, kwds, "O|$O:ContextVar", kwlist, &name, &def))
+ {
+ return NULL;
+ }
+
+ return (PyObject *)contextvar_new(name, def);
+}
+
+static int
+contextvar_tp_clear(PyContextVar *self)
+{
+ Py_CLEAR(self->var_name);
+ Py_CLEAR(self->var_default);
+ self->var_cached = NULL;
+ self->var_cached_tsid = 0;
+ self->var_cached_tsver = 0;
+ return 0;
+}
+
+static int
+contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg)
+{
+ Py_VISIT(self->var_name);
+ Py_VISIT(self->var_default);
+ return 0;
+}
+
+static void
+contextvar_tp_dealloc(PyContextVar *self)
+{
+ PyObject_GC_UnTrack(self);
+ (void)contextvar_tp_clear(self);
+ Py_TYPE(self)->tp_free(self);
+}
+
+static Py_hash_t
+contextvar_tp_hash(PyContextVar *self)
+{
+ return self->var_hash;
+}
+
+static PyObject *
+contextvar_tp_repr(PyContextVar *self)
+{
+ _PyUnicodeWriter writer;
+
+ _PyUnicodeWriter_Init(&writer);
+
+ if (_PyUnicodeWriter_WriteASCIIString(
+ &writer, "<ContextVar name=", 17) < 0)
+ {
+ goto error;
+ }
+
+ PyObject *name = PyObject_Repr(self->var_name);
+ if (name == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, name) < 0) {
+ Py_DECREF(name);
+ goto error;
+ }
+ Py_DECREF(name);
+
+ if (self->var_default != NULL) {
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, " default=", 9) < 0) {
+ goto error;
+ }
+
+ PyObject *def = PyObject_Repr(self->var_default);
+ if (def == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, def) < 0) {
+ Py_DECREF(def);
+ goto error;
+ }
+ Py_DECREF(def);
+ }
+
+ PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
+ if (addr == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
+ Py_DECREF(addr);
+ goto error;
+ }
+ Py_DECREF(addr);
+
+ return _PyUnicodeWriter_Finish(&writer);
+
+error:
+ _PyUnicodeWriter_Dealloc(&writer);
+ return NULL;
+}
+
+
+/*[clinic input]
+_contextvars.ContextVar.get
+ default: object = NULL
+ /
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value)
+/*[clinic end generated code: output=0746bd0aa2ced7bf input=8d002b02eebbb247]*/
+{
+ if (!PyContextVar_CheckExact(self)) {
+ PyErr_SetString(
+ PyExc_TypeError, "an instance of ContextVar was expected");
+ return NULL;
+ }
+
+ PyObject *val;
+ if (PyContextVar_Get(self, default_value, &val) < 0) {
+ return NULL;
+ }
+
+ if (val == NULL) {
+ PyErr_SetObject(PyExc_LookupError, (PyObject *)self);
+ return NULL;
+ }
+
+ return val;
+}
+
+/*[clinic input]
+_contextvars.ContextVar.set
+ value: object
+ /
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_ContextVar_set(PyContextVar *self, PyObject *value)
+/*[clinic end generated code: output=446ed5e820d6d60b input=a2d88f57c6d86f7c]*/
+{
+ return (PyObject *)PyContextVar_Set(self, value);
+}
+
+/*[clinic input]
+_contextvars.ContextVar.reset
+ token: object
+ /
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_ContextVar_reset(PyContextVar *self, PyObject *token)
+/*[clinic end generated code: output=d4ee34d0742d62ee input=4c871b6f1f31a65f]*/
+{
+ if (!PyContextToken_CheckExact(token)) {
+ PyErr_Format(PyExc_TypeError,
+ "expected an instance of Token, got %R", token);
+ return NULL;
+ }
+
+ if (PyContextVar_Reset(self, (PyContextToken *)token)) {
+ return NULL;
+ }
+
+ Py_RETURN_NONE;
+}
+
+
+static PyObject *
+contextvar_cls_getitem(PyObject *self, PyObject *args)
+{
+ Py_RETURN_NONE;
+}
+
+
+static PyMethodDef PyContextVar_methods[] = {
+ _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF
+ _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF
+ _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF
+ {"__class_getitem__", contextvar_cls_getitem,
+ METH_VARARGS | METH_STATIC, NULL},
+ {NULL, NULL}
+};
+
+PyTypeObject PyContextVar_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "ContextVar",
+ sizeof(PyContextVar),
+ .tp_methods = PyContextVar_methods,
+ .tp_dealloc = (destructor)contextvar_tp_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_traverse = (traverseproc)contextvar_tp_traverse,
+ .tp_clear = (inquiry)contextvar_tp_clear,
+ .tp_new = contextvar_tp_new,
+ .tp_free = PyObject_GC_Del,
+ .tp_hash = (hashfunc)contextvar_tp_hash,
+ .tp_repr = (reprfunc)contextvar_tp_repr,
+};
+
+
+/////////////////////////// Token
+
+static PyObject * get_token_missing(void);
+
+
+/*[clinic input]
+class _contextvars.Token "PyContextToken *" "&PyContextToken_Type"
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/
+
+
+static PyObject *
+token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ PyErr_SetString(PyExc_RuntimeError,
+ "Tokens can only be created by ContextVars");
+ return NULL;
+}
+
+static int
+token_tp_clear(PyContextToken *self)
+{
+ Py_CLEAR(self->tok_ctx);
+ Py_CLEAR(self->tok_var);
+ Py_CLEAR(self->tok_oldval);
+ return 0;
+}
+
+static int
+token_tp_traverse(PyContextToken *self, visitproc visit, void *arg)
+{
+ Py_VISIT(self->tok_ctx);
+ Py_VISIT(self->tok_var);
+ Py_VISIT(self->tok_oldval);
+ return 0;
+}
+
+static void
+token_tp_dealloc(PyContextToken *self)
+{
+ PyObject_GC_UnTrack(self);
+ (void)token_tp_clear(self);
+ Py_TYPE(self)->tp_free(self);
+}
+
+static PyObject *
+token_tp_repr(PyContextToken *self)
+{
+ _PyUnicodeWriter writer;
+
+ _PyUnicodeWriter_Init(&writer);
+
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, "<Token", 6) < 0) {
+ goto error;
+ }
+
+ if (self->tok_used) {
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) {
+ goto error;
+ }
+ }
+
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) {
+ goto error;
+ }
+
+ PyObject *var = PyObject_Repr((PyObject *)self->tok_var);
+ if (var == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) {
+ Py_DECREF(var);
+ goto error;
+ }
+ Py_DECREF(var);
+
+ PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
+ if (addr == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
+ Py_DECREF(addr);
+ goto error;
+ }
+ Py_DECREF(addr);
+
+ return _PyUnicodeWriter_Finish(&writer);
+
+error:
+ _PyUnicodeWriter_Dealloc(&writer);
+ return NULL;
+}
+
+static PyObject *
+token_get_var(PyContextToken *self)
+{
+ Py_INCREF(self->tok_var);
+ return (PyObject *)self->tok_var;
+}
+
+static PyObject *
+token_get_old_value(PyContextToken *self)
+{
+ if (self->tok_oldval == NULL) {
+ return get_token_missing();
+ }
+
+ Py_INCREF(self->tok_oldval);
+ return self->tok_oldval;
+}
+
+static PyGetSetDef PyContextTokenType_getsetlist[] = {
+ {"var", (getter)token_get_var, NULL, NULL},
+ {"old_value", (getter)token_get_old_value, NULL, NULL},
+ {NULL}
+};
+
+PyTypeObject PyContextToken_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "Token",
+ sizeof(PyContextToken),
+ .tp_getset = PyContextTokenType_getsetlist,
+ .tp_dealloc = (destructor)token_tp_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_traverse = (traverseproc)token_tp_traverse,
+ .tp_clear = (inquiry)token_tp_clear,
+ .tp_new = token_tp_new,
+ .tp_free = PyObject_GC_Del,
+ .tp_hash = PyObject_HashNotImplemented,
+ .tp_repr = (reprfunc)token_tp_repr,
+};
+
+static PyContextToken *
+token_new(PyContext *ctx, PyContextVar *var, PyObject *val)
+{
+ PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type);
+ if (tok == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(ctx);
+ tok->tok_ctx = ctx;
+
+ Py_INCREF(var);
+ tok->tok_var = var;
+
+ Py_XINCREF(val);
+ tok->tok_oldval = val;
+
+ tok->tok_used = 0;
+
+ PyObject_GC_Track(tok);
+ return tok;
+}
+
+
+/////////////////////////// Token.MISSING
+
+
+static PyObject *_token_missing;
+
+
+typedef struct {
+ PyObject_HEAD
+} PyContextTokenMissing;
+
+
+static PyObject *
+context_token_missing_tp_repr(PyObject *self)
+{
+ return PyUnicode_FromString("<Token.MISSING>");
+}
+
+
+PyTypeObject PyContextTokenMissing_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "Token.MISSING",
+ sizeof(PyContextTokenMissing),
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT,
+ .tp_repr = context_token_missing_tp_repr,
+};
+
+
+static PyObject *
+get_token_missing(void)
+{
+ if (_token_missing != NULL) {
+ Py_INCREF(_token_missing);
+ return _token_missing;
+ }
+
+ _token_missing = (PyObject *)PyObject_New(
+ PyContextTokenMissing, &PyContextTokenMissing_Type);
+ if (_token_missing == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(_token_missing);
+ return _token_missing;
+}
+
+
+///////////////////////////
+
+
+int
+PyContext_ClearFreeList(void)
+{
+ int size = ctx_freelist_len;
+ while (ctx_freelist_len) {
+ PyContext *ctx = ctx_freelist;
+ ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
+ ctx->ctx_weakreflist = NULL;
+ PyObject_GC_Del(ctx);
+ ctx_freelist_len--;
+ }
+ return size;
+}
+
+
+void
+_PyContext_Fini(void)
+{
+ Py_CLEAR(_token_missing);
+ (void)PyContext_ClearFreeList();
+ (void)_PyHamt_Fini();
+}
+
+
+int
+_PyContext_Init(void)
+{
+ if (!_PyHamt_Init()) {
+ return 0;
+ }
+
+ if ((PyType_Ready(&PyContext_Type) < 0) ||
+ (PyType_Ready(&PyContextVar_Type) < 0) ||
+ (PyType_Ready(&PyContextToken_Type) < 0) ||
+ (PyType_Ready(&PyContextTokenMissing_Type) < 0))
+ {
+ return 0;
+ }
+
+ PyObject *missing = get_token_missing();
+ if (PyDict_SetItemString(
+ PyContextToken_Type.tp_dict, "MISSING", missing))
+ {
+ Py_DECREF(missing);
+ return 0;
+ }
+ Py_DECREF(missing);
+
+ return 1;
+}
diff --git a/Python/hamt.c b/Python/hamt.c
new file mode 100644
index 0000000..8ba5082
--- /dev/null
+++ b/Python/hamt.c
@@ -0,0 +1,2982 @@
+#include "Python.h"
+
+#include "structmember.h"
+#include "internal/pystate.h"
+#include "internal/hamt.h"
+
+/* popcnt support in Visual Studio */
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+
+/*
+This file provides an implemention of an immutable mapping using the
+Hash Array Mapped Trie (or HAMT) datastructure.
+
+This design allows to have:
+
+1. Efficient copy: immutable mappings can be copied by reference,
+ making it an O(1) operation.
+
+2. Efficient mutations: due to structural sharing, only a portion of
+ the trie needs to be copied when the collection is mutated. The
+ cost of set/delete operations is O(log N).
+
+3. Efficient lookups: O(log N).
+
+(where N is number of key/value items in the immutable mapping.)
+
+
+HAMT
+====
+
+The core idea of HAMT is that the shape of the trie is encoded into the
+hashes of keys.
+
+Say we want to store a K/V pair in our mapping. First, we calculate the
+hash of K, let's say it's 19830128, or in binary:
+
+ 0b1001011101001010101110000 = 19830128
+
+Now let's partition this bit representation of the hash into blocks of
+5 bits each:
+
+ 0b00_00000_10010_11101_00101_01011_10000 = 19830128
+ (6) (5) (4) (3) (2) (1)
+
+Each block of 5 bits represents a number betwen 0 and 31. So if we have
+a tree that consists of nodes, each of which is an array of 32 pointers,
+those 5-bit blocks will encode a position on a single tree level.
+
+For example, storing the key K with hash 19830128, results in the following
+tree structure:
+
+ (array of 32 pointers)
+ +---+ -- +----+----+----+ -- +----+
+ root node | 0 | .. | 15 | 16 | 17 | .. | 31 | 0b10000 = 16 (1)
+ (level 1) +---+ -- +----+----+----+ -- +----+
+ |
+ +---+ -- +----+----+----+ -- +----+
+ a 2nd level node | 0 | .. | 10 | 11 | 12 | .. | 31 | 0b01011 = 11 (2)
+ +---+ -- +----+----+----+ -- +----+
+ |
+ +---+ -- +----+----+----+ -- +----+
+ a 3rd level node | 0 | .. | 04 | 05 | 06 | .. | 31 | 0b01011 = 5 (3)
+ +---+ -- +----+----+----+ -- +----+
+ |
+ +---+ -- +----+----+----+----+
+ a 4th level node | 0 | .. | 04 | 29 | 30 | 31 | 0b11101 = 29 (4)
+ +---+ -- +----+----+----+----+
+ |
+ +---+ -- +----+----+----+ -- +----+
+ a 5th level node | 0 | .. | 17 | 18 | 19 | .. | 31 | 0b10010 = 18 (5)
+ +---+ -- +----+----+----+ -- +----+
+ |
+ +--------------+
+ |
+ +---+ -- +----+----+----+ -- +----+
+ a 6th level node | 0 | .. | 15 | 16 | 17 | .. | 31 | 0b00000 = 0 (6)
+ +---+ -- +----+----+----+ -- +----+
+ |
+ V -- our value (or collision)
+
+To rehash: for a K/V pair, the hash of K encodes where in the tree V will
+be stored.
+
+To optimize memory footprint and handle hash collisions, our implementation
+uses three different types of nodes:
+
+ * A Bitmap node;
+ * An Array node;
+ * A Collision node.
+
+Because we implement an immutable dictionary, our nodes are also
+immutable. Therefore, when we need to modify a node, we copy it, and
+do that modification to the copy.
+
+
+Array Nodes
+-----------
+
+These nodes are very simple. Essentially they are arrays of 32 pointers
+we used to illustrate the high-level idea in the previous section.
+
+We use Array nodes only when we need to store more than 16 pointers
+in a single node.
+
+Array nodes do not store key objects or value objects. They are used
+only as an indirection level - their pointers point to other nodes in
+the tree.
+
+
+Bitmap Node
+-----------
+
+Allocating a new 32-pointers array for every node of our tree would be
+very expensive. Unless we store millions of keys, most of tree nodes would
+be very sparse.
+
+When we have less than 16 elements in a node, we don't want to use the
+Array node, that would mean that we waste a lot of memory. Instead,
+we can use bitmap compression and can have just as many pointers
+as we need!
+
+Bitmap nodes consist of two fields:
+
+1. An array of pointers. If a Bitmap node holds N elements, the
+ array will be of N pointers.
+
+2. A 32bit integer -- a bitmap field. If an N-th bit is set in the
+ bitmap, it means that the node has an N-th element.
+
+For example, say we need to store a 3 elements sparse array:
+
+ +---+ -- +---+ -- +----+ -- +----+
+ | 0 | .. | 4 | .. | 11 | .. | 17 |
+ +---+ -- +---+ -- +----+ -- +----+
+ | | |
+ o1 o2 o3
+
+We allocate a three-pointer Bitmap node. Its bitmap field will be
+then set to:
+
+ 0b_00100_00010_00000_10000 == (1 << 17) | (1 << 11) | (1 << 4)
+
+To check if our Bitmap node has an I-th element we can do:
+
+ bitmap & (1 << I)
+
+
+And here's a formula to calculate a position in our pointer array
+which would correspond to an I-th element:
+
+ popcount(bitmap & ((1 << I) - 1))
+
+
+Let's break it down:
+
+ * `popcount` is a function that returns a number of bits set to 1;
+
+ * `((1 << I) - 1)` is a mask to filter the bitmask to contain bits
+ set to the *right* of our bit.
+
+
+So for our 17, 11, and 4 indexes:
+
+ * bitmap & ((1 << 17) - 1) == 0b100000010000 => 2 bits are set => index is 2.
+
+ * bitmap & ((1 << 11) - 1) == 0b10000 => 1 bit is set => index is 1.
+
+ * bitmap & ((1 << 4) - 1) == 0b0 => 0 bits are set => index is 0.
+
+
+To conclude: Bitmap nodes are just like Array nodes -- they can store
+a number of pointers, but use bitmap compression to eliminate unused
+pointers.
+
+
+Bitmap nodes have two pointers for each item:
+
+ +----+----+----+----+ -- +----+----+
+ | k1 | v1 | k2 | v2 | .. | kN | vN |
+ +----+----+----+----+ -- +----+----+
+
+When kI == NULL, vI points to another tree level.
+
+When kI != NULL, the actual key object is stored in kI, and its
+value is stored in vI.
+
+
+Collision Nodes
+---------------
+
+Collision nodes are simple arrays of pointers -- two pointers per
+key/value. When there's a hash collision, say for k1/v1 and k2/v2
+we have `hash(k1)==hash(k2)`. Then our collision node will be:
+
+ +----+----+----+----+
+ | k1 | v1 | k2 | v2 |
+ +----+----+----+----+
+
+
+Tree Structure
+--------------
+
+All nodes are PyObjects.
+
+The `PyHamtObject` object has a pointer to the root node (h_root),
+and has a length field (h_count).
+
+High-level functions accept a PyHamtObject object and dispatch to
+lower-level functions depending on what kind of node h_root points to.
+
+
+Operations
+==========
+
+There are three fundamental operations on an immutable dictionary:
+
+1. "o.assoc(k, v)" will return a new immutable dictionary, that will be
+ a copy of "o", but with the "k/v" item set.
+
+ Functions in this file:
+
+ hamt_node_assoc, hamt_node_bitmap_assoc,
+ hamt_node_array_assoc, hamt_node_collision_assoc
+
+ `hamt_node_assoc` function accepts a node object, and calls
+ other functions depending on its actual type.
+
+2. "o.find(k)" will lookup key "k" in "o".
+
+ Functions:
+
+ hamt_node_find, hamt_node_bitmap_find,
+ hamt_node_array_find, hamt_node_collision_find
+
+3. "o.without(k)" will return a new immutable dictionary, that will be
+ a copy of "o", buth without the "k" key.
+
+ Functions:
+
+ hamt_node_without, hamt_node_bitmap_without,
+ hamt_node_array_without, hamt_node_collision_without
+
+
+Further Reading
+===============
+
+1. http://blog.higher-order.net/2009/09/08/understanding-clojures-persistenthashmap-deftwice.html
+
+2. http://blog.higher-order.net/2010/08/16/assoc-and-clojures-persistenthashmap-part-ii.html
+
+3. Clojure's PersistentHashMap implementation:
+ https://github.com/clojure/clojure/blob/master/src/jvm/clojure/lang/PersistentHashMap.java
+
+
+Debug
+=====
+
+The HAMT datatype is accessible for testing purposes under the
+`_testcapi` module:
+
+ >>> from _testcapi import hamt
+ >>> h = hamt()
+ >>> h2 = h.set('a', 2)
+ >>> h3 = h2.set('b', 3)
+ >>> list(h3)
+ ['a', 'b']
+
+When CPython is built in debug mode, a '__dump__()' method is available
+to introspect the tree:
+
+ >>> print(h3.__dump__())
+ HAMT(len=2):
+ BitmapNode(size=4 count=2 bitmap=0b110 id=0x10eb9d9e8):
+ 'a': 2
+ 'b': 3
+*/
+
+
+#define IS_ARRAY_NODE(node) (Py_TYPE(node) == &_PyHamt_ArrayNode_Type)
+#define IS_BITMAP_NODE(node) (Py_TYPE(node) == &_PyHamt_BitmapNode_Type)
+#define IS_COLLISION_NODE(node) (Py_TYPE(node) == &_PyHamt_CollisionNode_Type)
+
+
+/* Return type for 'find' (lookup a key) functions.
+
+ * F_ERROR - an error occurred;
+ * F_NOT_FOUND - the key was not found;
+ * F_FOUND - the key was found.
+*/
+typedef enum {F_ERROR, F_NOT_FOUND, F_FOUND} hamt_find_t;
+
+
+/* Return type for 'without' (delete a key) functions.
+
+ * W_ERROR - an error occurred;
+ * W_NOT_FOUND - the key was not found: there's nothing to delete;
+ * W_EMPTY - the key was found: the node/tree would be empty
+ if the key is deleted;
+ * W_NEWNODE - the key was found: a new node/tree is returned
+ without that key.
+*/
+typedef enum {W_ERROR, W_NOT_FOUND, W_EMPTY, W_NEWNODE} hamt_without_t;
+
+
+/* Low-level iterator protocol type.
+
+ * I_ITEM - a new item has been yielded;
+ * I_END - the whole tree was visited (similar to StopIteration).
+*/
+typedef enum {I_ITEM, I_END} hamt_iter_t;
+
+
+#define HAMT_ARRAY_NODE_SIZE 32
+
+
+typedef struct {
+ PyObject_HEAD
+ PyHamtNode *a_array[HAMT_ARRAY_NODE_SIZE];
+ Py_ssize_t a_count;
+} PyHamtNode_Array;
+
+
+typedef struct {
+ PyObject_VAR_HEAD
+ uint32_t b_bitmap;
+ PyObject *b_array[1];
+} PyHamtNode_Bitmap;
+
+
+typedef struct {
+ PyObject_VAR_HEAD
+ int32_t c_hash;
+ PyObject *c_array[1];
+} PyHamtNode_Collision;
+
+
+static PyHamtNode_Bitmap *_empty_bitmap_node;
+static PyHamtObject *_empty_hamt;
+
+
+static PyHamtObject *
+hamt_alloc(void);
+
+static PyHamtNode *
+hamt_node_assoc(PyHamtNode *node,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject *val, int* added_leaf);
+
+static hamt_without_t
+hamt_node_without(PyHamtNode *node,
+ uint32_t shift, int32_t hash,
+ PyObject *key,
+ PyHamtNode **new_node);
+
+static hamt_find_t
+hamt_node_find(PyHamtNode *node,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject **val);
+
+#ifdef Py_DEBUG
+static int
+hamt_node_dump(PyHamtNode *node,
+ _PyUnicodeWriter *writer, int level);
+#endif
+
+static PyHamtNode *
+hamt_node_array_new(Py_ssize_t);
+
+static PyHamtNode *
+hamt_node_collision_new(int32_t hash, Py_ssize_t size);
+
+static inline Py_ssize_t
+hamt_node_collision_count(PyHamtNode_Collision *node);
+
+
+#ifdef Py_DEBUG
+static void
+_hamt_node_array_validate(void *o)
+{
+ assert(IS_ARRAY_NODE(o));
+ PyHamtNode_Array *node = (PyHamtNode_Array*)(o);
+ Py_ssize_t i = 0, count = 0;
+ for (; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ if (node->a_array[i] != NULL) {
+ count++;
+ }
+ }
+ assert(count == node->a_count);
+}
+
+#define VALIDATE_ARRAY_NODE(NODE) \
+ do { _hamt_node_array_validate(NODE); } while (0);
+#else
+#define VALIDATE_ARRAY_NODE(NODE)
+#endif
+
+
+/* Returns -1 on error */
+static inline int32_t
+hamt_hash(PyObject *o)
+{
+ Py_hash_t hash = PyObject_Hash(o);
+
+#if SIZEOF_PY_HASH_T <= 4
+ return hash;
+#else
+ if (hash == -1) {
+ /* exception */
+ return -1;
+ }
+
+ /* While it's suboptimal to reduce Python's 64 bit hash to
+ 32 bits via XOR, it seems that the resulting hash function
+ is good enough (this is also how Long type is hashed in Java.)
+ Storing 10, 100, 1000 Python strings results in a relatively
+ shallow and uniform tree structure.
+
+ Please don't change this hashing algorithm, as there are many
+ tests that test some exact tree shape to cover all code paths.
+ */
+ int32_t xored = (int32_t)(hash & 0xffffffffl) ^ (int32_t)(hash >> 32);
+ return xored == -1 ? -2 : xored;
+#endif
+}
+
+static inline uint32_t
+hamt_mask(int32_t hash, uint32_t shift)
+{
+ return (((uint32_t)hash >> shift) & 0x01f);
+}
+
+static inline uint32_t
+hamt_bitpos(int32_t hash, uint32_t shift)
+{
+ return (uint32_t)1 << hamt_mask(hash, shift);
+}
+
+static inline uint32_t
+hamt_bitcount(uint32_t i)
+{
+#if defined(__GNUC__) && (__GNUC__ > 4)
+ return (uint32_t)__builtin_popcountl(i);
+#elif defined(__clang__) && (__clang_major__ > 3)
+ return (uint32_t)__builtin_popcountl(i);
+#elif defined(_MSC_VER)
+ return (uint32_t)__popcnt(i);
+#else
+ /* https://graphics.stanford.edu/~seander/bithacks.html */
+ i = i - ((i >> 1) & 0x55555555);
+ i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
+ return ((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) >> 24;
+#endif
+}
+
+static inline uint32_t
+hamt_bitindex(uint32_t bitmap, uint32_t bit)
+{
+ return hamt_bitcount(bitmap & (bit - 1));
+}
+
+
+/////////////////////////////////// Dump Helpers
+#ifdef Py_DEBUG
+
+static int
+_hamt_dump_ident(_PyUnicodeWriter *writer, int level)
+{
+ /* Write `' ' * level` to the `writer` */
+ PyObject *str = NULL;
+ PyObject *num = NULL;
+ PyObject *res = NULL;
+ int ret = -1;
+
+ str = PyUnicode_FromString(" ");
+ if (str == NULL) {
+ goto error;
+ }
+
+ num = PyLong_FromLong((long)level);
+ if (num == NULL) {
+ goto error;
+ }
+
+ res = PyNumber_Multiply(str, num);
+ if (res == NULL) {
+ goto error;
+ }
+
+ ret = _PyUnicodeWriter_WriteStr(writer, res);
+
+error:
+ Py_XDECREF(res);
+ Py_XDECREF(str);
+ Py_XDECREF(num);
+ return ret;
+}
+
+static int
+_hamt_dump_format(_PyUnicodeWriter *writer, const char *format, ...)
+{
+ /* A convenient helper combining _PyUnicodeWriter_WriteStr and
+ PyUnicode_FromFormatV.
+ */
+ PyObject* msg;
+ int ret;
+
+ va_list vargs;
+#ifdef HAVE_STDARG_PROTOTYPES
+ va_start(vargs, format);
+#else
+ va_start(vargs);
+#endif
+ msg = PyUnicode_FromFormatV(format, vargs);
+ va_end(vargs);
+
+ if (msg == NULL) {
+ return -1;
+ }
+
+ ret = _PyUnicodeWriter_WriteStr(writer, msg);
+ Py_DECREF(msg);
+ return ret;
+}
+
+#endif /* Py_DEBUG */
+/////////////////////////////////// Bitmap Node
+
+
+static PyHamtNode *
+hamt_node_bitmap_new(Py_ssize_t size)
+{
+ /* Create a new bitmap node of size 'size' */
+
+ PyHamtNode_Bitmap *node;
+ Py_ssize_t i;
+
+ assert(size >= 0);
+ assert(size % 2 == 0);
+
+ if (size == 0 && _empty_bitmap_node != NULL) {
+ Py_INCREF(_empty_bitmap_node);
+ return (PyHamtNode *)_empty_bitmap_node;
+ }
+
+ /* No freelist; allocate a new bitmap node */
+ node = PyObject_GC_NewVar(
+ PyHamtNode_Bitmap, &_PyHamt_BitmapNode_Type, size);
+ if (node == NULL) {
+ return NULL;
+ }
+
+ Py_SIZE(node) = size;
+
+ for (i = 0; i < size; i++) {
+ node->b_array[i] = NULL;
+ }
+
+ node->b_bitmap = 0;
+
+ _PyObject_GC_TRACK(node);
+
+ if (size == 0 && _empty_bitmap_node == NULL) {
+ /* Since bitmap nodes are immutable, we can cache the instance
+ for size=0 and reuse it whenever we need an empty bitmap node.
+ */
+ _empty_bitmap_node = node;
+ Py_INCREF(_empty_bitmap_node);
+ }
+
+ return (PyHamtNode *)node;
+}
+
+static inline Py_ssize_t
+hamt_node_bitmap_count(PyHamtNode_Bitmap *node)
+{
+ return Py_SIZE(node) / 2;
+}
+
+static PyHamtNode_Bitmap *
+hamt_node_bitmap_clone(PyHamtNode_Bitmap *node)
+{
+ /* Clone a bitmap node; return a new one with the same child notes. */
+
+ PyHamtNode_Bitmap *clone;
+ Py_ssize_t i;
+
+ clone = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(Py_SIZE(node));
+ if (clone == NULL) {
+ return NULL;
+ }
+
+ for (i = 0; i < Py_SIZE(node); i++) {
+ Py_XINCREF(node->b_array[i]);
+ clone->b_array[i] = node->b_array[i];
+ }
+
+ clone->b_bitmap = node->b_bitmap;
+ return clone;
+}
+
+static PyHamtNode_Bitmap *
+hamt_node_bitmap_clone_without(PyHamtNode_Bitmap *o, uint32_t bit)
+{
+ assert(bit & o->b_bitmap);
+ assert(hamt_node_bitmap_count(o) > 1);
+
+ PyHamtNode_Bitmap *new = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(
+ Py_SIZE(o) - 2);
+ if (new == NULL) {
+ return NULL;
+ }
+
+ uint32_t idx = hamt_bitindex(o->b_bitmap, bit);
+ uint32_t key_idx = 2 * idx;
+ uint32_t val_idx = key_idx + 1;
+ uint32_t i;
+
+ for (i = 0; i < key_idx; i++) {
+ Py_XINCREF(o->b_array[i]);
+ new->b_array[i] = o->b_array[i];
+ }
+
+ for (i = val_idx + 1; i < Py_SIZE(o); i++) {
+ Py_XINCREF(o->b_array[i]);
+ new->b_array[i - 2] = o->b_array[i];
+ }
+
+ new->b_bitmap = o->b_bitmap & ~bit;
+ return new;
+}
+
+static PyHamtNode *
+hamt_node_new_bitmap_or_collision(uint32_t shift,
+ PyObject *key1, PyObject *val1,
+ int32_t key2_hash,
+ PyObject *key2, PyObject *val2)
+{
+ /* Helper method. Creates a new node for key1/val and key2/val2
+ pairs.
+
+ If key1 hash is equal to the hash of key2, a Collision node
+ will be created. If they are not equal, a Bitmap node is
+ created.
+ */
+
+ int32_t key1_hash = hamt_hash(key1);
+ if (key1_hash == -1) {
+ return NULL;
+ }
+
+ if (key1_hash == key2_hash) {
+ PyHamtNode_Collision *n;
+ n = (PyHamtNode_Collision *)hamt_node_collision_new(key1_hash, 4);
+ if (n == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(key1);
+ n->c_array[0] = key1;
+ Py_INCREF(val1);
+ n->c_array[1] = val1;
+
+ Py_INCREF(key2);
+ n->c_array[2] = key2;
+ Py_INCREF(val2);
+ n->c_array[3] = val2;
+
+ return (PyHamtNode *)n;
+ }
+ else {
+ int added_leaf = 0;
+ PyHamtNode *n = hamt_node_bitmap_new(0);
+ if (n == NULL) {
+ return NULL;
+ }
+
+ PyHamtNode *n2 = hamt_node_assoc(
+ n, shift, key1_hash, key1, val1, &added_leaf);
+ Py_DECREF(n);
+ if (n2 == NULL) {
+ return NULL;
+ }
+
+ n = hamt_node_assoc(n2, shift, key2_hash, key2, val2, &added_leaf);
+ Py_DECREF(n2);
+ if (n == NULL) {
+ return NULL;
+ }
+
+ return n;
+ }
+}
+
+static PyHamtNode *
+hamt_node_bitmap_assoc(PyHamtNode_Bitmap *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject *val, int* added_leaf)
+{
+ /* assoc operation for bitmap nodes.
+
+ Return: a new node, or self if key/val already is in the
+ collection.
+
+ 'added_leaf' is later used in '_PyHamt_Assoc' to determine if
+ `hamt.set(key, val)` increased the size of the collection.
+ */
+
+ uint32_t bit = hamt_bitpos(hash, shift);
+ uint32_t idx = hamt_bitindex(self->b_bitmap, bit);
+
+ /* Bitmap node layout:
+
+ +------+------+------+------+ --- +------+------+
+ | key1 | val1 | key2 | val2 | ... | keyN | valN |
+ +------+------+------+------+ --- +------+------+
+ where `N < Py_SIZE(node)`.
+
+ The `node->b_bitmap` field is a bitmap. For a given
+ `(shift, hash)` pair we can determine:
+
+ - If this node has the corresponding key/val slots.
+ - The index of key/val slots.
+ */
+
+ if (self->b_bitmap & bit) {
+ /* The key is set in this node */
+
+ uint32_t key_idx = 2 * idx;
+ uint32_t val_idx = key_idx + 1;
+
+ assert(val_idx < Py_SIZE(self));
+
+ PyObject *key_or_null = self->b_array[key_idx];
+ PyObject *val_or_node = self->b_array[val_idx];
+
+ if (key_or_null == NULL) {
+ /* key is NULL. This means that we have a few keys
+ that have the same (hash, shift) pair. */
+
+ assert(val_or_node != NULL);
+
+ PyHamtNode *sub_node = hamt_node_assoc(
+ (PyHamtNode *)val_or_node,
+ shift + 5, hash, key, val, added_leaf);
+ if (sub_node == NULL) {
+ return NULL;
+ }
+
+ if (val_or_node == (PyObject *)sub_node) {
+ Py_DECREF(sub_node);
+ Py_INCREF(self);
+ return (PyHamtNode *)self;
+ }
+
+ PyHamtNode_Bitmap *ret = hamt_node_bitmap_clone(self);
+ if (ret == NULL) {
+ return NULL;
+ }
+ Py_SETREF(ret->b_array[val_idx], (PyObject*)sub_node);
+ return (PyHamtNode *)ret;
+ }
+
+ assert(key != NULL);
+ /* key is not NULL. This means that we have only one other
+ key in this collection that matches our hash for this shift. */
+
+ int comp_err = PyObject_RichCompareBool(key, key_or_null, Py_EQ);
+ if (comp_err < 0) { /* exception in __eq__ */
+ return NULL;
+ }
+ if (comp_err == 1) { /* key == key_or_null */
+ if (val == val_or_node) {
+ /* we already have the same key/val pair; return self. */
+ Py_INCREF(self);
+ return (PyHamtNode *)self;
+ }
+
+ /* We're setting a new value for the key we had before.
+ Make a new bitmap node with a replaced value, and return it. */
+ PyHamtNode_Bitmap *ret = hamt_node_bitmap_clone(self);
+ if (ret == NULL) {
+ return NULL;
+ }
+ Py_INCREF(val);
+ Py_SETREF(ret->b_array[val_idx], val);
+ return (PyHamtNode *)ret;
+ }
+
+ /* It's a new key, and it has the same index as *one* another key.
+ We have a collision. We need to create a new node which will
+ combine the existing key and the key we're adding.
+
+ `hamt_node_new_bitmap_or_collision` will either create a new
+ Collision node if the keys have identical hashes, or
+ a new Bitmap node.
+ */
+ PyHamtNode *sub_node = hamt_node_new_bitmap_or_collision(
+ shift + 5,
+ key_or_null, val_or_node, /* existing key/val */
+ hash,
+ key, val /* new key/val */
+ );
+ if (sub_node == NULL) {
+ return NULL;
+ }
+
+ PyHamtNode_Bitmap *ret = hamt_node_bitmap_clone(self);
+ if (ret == NULL) {
+ Py_DECREF(sub_node);
+ return NULL;
+ }
+ Py_SETREF(ret->b_array[key_idx], NULL);
+ Py_SETREF(ret->b_array[val_idx], (PyObject *)sub_node);
+
+ *added_leaf = 1;
+ return (PyHamtNode *)ret;
+ }
+ else {
+ /* There was no key before with the same (shift,hash). */
+
+ uint32_t n = hamt_bitcount(self->b_bitmap);
+
+ if (n >= 16) {
+ /* When we have a situation where we want to store more
+ than 16 nodes at one level of the tree, we no longer
+ want to use the Bitmap node with bitmap encoding.
+
+ Instead we start using an Array node, which has
+ simpler (faster) implementation at the expense of
+ having prealocated 32 pointers for its keys/values
+ pairs.
+
+ Small hamt objects (<30 keys) usually don't have any
+ Array nodes at all. Betwen ~30 and ~400 keys hamt
+ objects usually have one Array node, and usually it's
+ a root node.
+ */
+
+ uint32_t jdx = hamt_mask(hash, shift);
+ /* 'jdx' is the index of where the new key should be added
+ in the new Array node we're about to create. */
+
+ PyHamtNode *empty = NULL;
+ PyHamtNode_Array *new_node = NULL;
+ PyHamtNode *res = NULL;
+
+ /* Create a new Array node. */
+ new_node = (PyHamtNode_Array *)hamt_node_array_new(n + 1);
+ if (new_node == NULL) {
+ goto fin;
+ }
+
+ /* Create an empty bitmap node for the next
+ hamt_node_assoc call. */
+ empty = hamt_node_bitmap_new(0);
+ if (empty == NULL) {
+ goto fin;
+ }
+
+ /* Make a new bitmap node for the key/val we're adding.
+ Set that bitmap node to new-array-node[jdx]. */
+ new_node->a_array[jdx] = hamt_node_assoc(
+ empty, shift + 5, hash, key, val, added_leaf);
+ if (new_node->a_array[jdx] == NULL) {
+ goto fin;
+ }
+
+ /* Copy existing key/value pairs from the current Bitmap
+ node to the new Array node we've just created. */
+ Py_ssize_t i, j;
+ for (i = 0, j = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ if (((self->b_bitmap >> i) & 1) != 0) {
+ /* Ensure we don't accidentally override `jdx` element
+ we set few lines above.
+ */
+ assert(new_node->a_array[i] == NULL);
+
+ if (self->b_array[j] == NULL) {
+ new_node->a_array[i] =
+ (PyHamtNode *)self->b_array[j + 1];
+ Py_INCREF(new_node->a_array[i]);
+ }
+ else {
+ int32_t rehash = hamt_hash(self->b_array[j]);
+ if (rehash == -1) {
+ goto fin;
+ }
+
+ new_node->a_array[i] = hamt_node_assoc(
+ empty, shift + 5,
+ rehash,
+ self->b_array[j],
+ self->b_array[j + 1],
+ added_leaf);
+
+ if (new_node->a_array[i] == NULL) {
+ goto fin;
+ }
+ }
+ j += 2;
+ }
+ }
+
+ VALIDATE_ARRAY_NODE(new_node)
+
+ /* That's it! */
+ res = (PyHamtNode *)new_node;
+
+ fin:
+ Py_XDECREF(empty);
+ if (res == NULL) {
+ Py_XDECREF(new_node);
+ }
+ return res;
+ }
+ else {
+ /* We have less than 16 keys at this level; let's just
+ create a new bitmap node out of this node with the
+ new key/val pair added. */
+
+ uint32_t key_idx = 2 * idx;
+ uint32_t val_idx = key_idx + 1;
+ Py_ssize_t i;
+
+ *added_leaf = 1;
+
+ /* Allocate new Bitmap node which can have one more key/val
+ pair in addition to what we have already. */
+ PyHamtNode_Bitmap *new_node =
+ (PyHamtNode_Bitmap *)hamt_node_bitmap_new(2 * (n + 1));
+ if (new_node == NULL) {
+ return NULL;
+ }
+
+ /* Copy all keys/values that will be before the new key/value
+ we are adding. */
+ for (i = 0; i < key_idx; i++) {
+ Py_XINCREF(self->b_array[i]);
+ new_node->b_array[i] = self->b_array[i];
+ }
+
+ /* Set the new key/value to the new Bitmap node. */
+ Py_INCREF(key);
+ new_node->b_array[key_idx] = key;
+ Py_INCREF(val);
+ new_node->b_array[val_idx] = val;
+
+ /* Copy all keys/values that will be after the new key/value
+ we are adding. */
+ for (i = key_idx; i < Py_SIZE(self); i++) {
+ Py_XINCREF(self->b_array[i]);
+ new_node->b_array[i + 2] = self->b_array[i];
+ }
+
+ new_node->b_bitmap = self->b_bitmap | bit;
+ return (PyHamtNode *)new_node;
+ }
+ }
+}
+
+static hamt_without_t
+hamt_node_bitmap_without(PyHamtNode_Bitmap *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key,
+ PyHamtNode **new_node)
+{
+ uint32_t bit = hamt_bitpos(hash, shift);
+ if ((self->b_bitmap & bit) == 0) {
+ return W_NOT_FOUND;
+ }
+
+ uint32_t idx = hamt_bitindex(self->b_bitmap, bit);
+
+ uint32_t key_idx = 2 * idx;
+ uint32_t val_idx = key_idx + 1;
+
+ PyObject *key_or_null = self->b_array[key_idx];
+ PyObject *val_or_node = self->b_array[val_idx];
+
+ if (key_or_null == NULL) {
+ /* key == NULL means that 'value' is another tree node. */
+
+ PyHamtNode *sub_node = NULL;
+
+ hamt_without_t res = hamt_node_without(
+ (PyHamtNode *)val_or_node,
+ shift + 5, hash, key, &sub_node);
+
+ switch (res) {
+ case W_EMPTY:
+ /* It's impossible for us to receive a W_EMPTY here:
+
+ - Array nodes are converted to Bitmap nodes when
+ we delete 16th item from them;
+
+ - Collision nodes are converted to Bitmap when
+ there is one item in them;
+
+ - Bitmap node's without() inlines single-item
+ sub-nodes.
+
+ So in no situation we can have a single-item
+ Bitmap child of another Bitmap node.
+ */
+ Py_UNREACHABLE();
+
+ case W_NEWNODE: {
+ assert(sub_node != NULL);
+
+ if (IS_BITMAP_NODE(sub_node)) {
+ PyHamtNode_Bitmap *sub_tree = (PyHamtNode_Bitmap *)sub_node;
+ if (hamt_node_bitmap_count(sub_tree) == 1 &&
+ sub_tree->b_array[0] != NULL)
+ {
+ /* A bitmap node with one key/value pair. Just
+ merge it into this node.
+
+ Note that we don't inline Bitmap nodes that
+ have a NULL key -- those nodes point to another
+ tree level, and we cannot simply move tree levels
+ up or down.
+ */
+
+ PyHamtNode_Bitmap *clone = hamt_node_bitmap_clone(self);
+ if (clone == NULL) {
+ Py_DECREF(sub_node);
+ return W_ERROR;
+ }
+
+ PyObject *key = sub_tree->b_array[0];
+ PyObject *val = sub_tree->b_array[1];
+
+ Py_INCREF(key);
+ Py_XSETREF(clone->b_array[key_idx], key);
+ Py_INCREF(val);
+ Py_SETREF(clone->b_array[val_idx], val);
+
+ Py_DECREF(sub_tree);
+
+ *new_node = (PyHamtNode *)clone;
+ return W_NEWNODE;
+ }
+ }
+
+#ifdef Py_DEBUG
+ /* Ensure that Collision.without implementation
+ converts to Bitmap nodes itself.
+ */
+ if (IS_COLLISION_NODE(sub_node)) {
+ assert(hamt_node_collision_count(
+ (PyHamtNode_Collision*)sub_node) > 1);
+ }
+#endif
+
+ PyHamtNode_Bitmap *clone = hamt_node_bitmap_clone(self);
+ Py_SETREF(clone->b_array[val_idx],
+ (PyObject *)sub_node); /* borrow */
+
+ *new_node = (PyHamtNode *)clone;
+ return W_NEWNODE;
+ }
+
+ case W_ERROR:
+ case W_NOT_FOUND:
+ assert(sub_node == NULL);
+ return res;
+
+ default:
+ Py_UNREACHABLE();
+ }
+ }
+ else {
+ /* We have a regular key/value pair */
+
+ int cmp = PyObject_RichCompareBool(key_or_null, key, Py_EQ);
+ if (cmp < 0) {
+ return W_ERROR;
+ }
+ if (cmp == 0) {
+ return W_NOT_FOUND;
+ }
+
+ if (hamt_node_bitmap_count(self) == 1) {
+ return W_EMPTY;
+ }
+
+ *new_node = (PyHamtNode *)
+ hamt_node_bitmap_clone_without(self, bit);
+ if (*new_node == NULL) {
+ return W_ERROR;
+ }
+
+ return W_NEWNODE;
+ }
+}
+
+static hamt_find_t
+hamt_node_bitmap_find(PyHamtNode_Bitmap *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject **val)
+{
+ /* Lookup a key in a Bitmap node. */
+
+ uint32_t bit = hamt_bitpos(hash, shift);
+ uint32_t idx;
+ uint32_t key_idx;
+ uint32_t val_idx;
+ PyObject *key_or_null;
+ PyObject *val_or_node;
+ int comp_err;
+
+ if ((self->b_bitmap & bit) == 0) {
+ return F_NOT_FOUND;
+ }
+
+ idx = hamt_bitindex(self->b_bitmap, bit);
+ assert(idx >= 0);
+ key_idx = idx * 2;
+ val_idx = key_idx + 1;
+
+ assert(val_idx < Py_SIZE(self));
+
+ key_or_null = self->b_array[key_idx];
+ val_or_node = self->b_array[val_idx];
+
+ if (key_or_null == NULL) {
+ /* There are a few keys that have the same hash at the current shift
+ that match our key. Dispatch the lookup further down the tree. */
+ assert(val_or_node != NULL);
+ return hamt_node_find((PyHamtNode *)val_or_node,
+ shift + 5, hash, key, val);
+ }
+
+ /* We have only one key -- a potential match. Let's compare if the
+ key we are looking at is equal to the key we are looking for. */
+ assert(key != NULL);
+ comp_err = PyObject_RichCompareBool(key, key_or_null, Py_EQ);
+ if (comp_err < 0) { /* exception in __eq__ */
+ return F_ERROR;
+ }
+ if (comp_err == 1) { /* key == key_or_null */
+ *val = val_or_node;
+ return F_FOUND;
+ }
+
+ return F_NOT_FOUND;
+}
+
+static int
+hamt_node_bitmap_traverse(PyHamtNode_Bitmap *self, visitproc visit, void *arg)
+{
+ /* Bitmap's tp_traverse */
+
+ Py_ssize_t i;
+
+ for (i = Py_SIZE(self); --i >= 0; ) {
+ Py_VISIT(self->b_array[i]);
+ }
+
+ return 0;
+}
+
+static void
+hamt_node_bitmap_dealloc(PyHamtNode_Bitmap *self)
+{
+ /* Bitmap's tp_dealloc */
+
+ Py_ssize_t len = Py_SIZE(self);
+ Py_ssize_t i;
+
+ PyObject_GC_UnTrack(self);
+ Py_TRASHCAN_SAFE_BEGIN(self)
+
+ if (len > 0) {
+ i = len;
+ while (--i >= 0) {
+ Py_XDECREF(self->b_array[i]);
+ }
+ }
+
+ Py_TYPE(self)->tp_free((PyObject *)self);
+ Py_TRASHCAN_SAFE_END(self)
+}
+
+#ifdef Py_DEBUG
+static int
+hamt_node_bitmap_dump(PyHamtNode_Bitmap *node,
+ _PyUnicodeWriter *writer, int level)
+{
+ /* Debug build: __dump__() method implementation for Bitmap nodes. */
+
+ Py_ssize_t i;
+ PyObject *tmp1;
+ PyObject *tmp2;
+
+ if (_hamt_dump_ident(writer, level + 1)) {
+ goto error;
+ }
+
+ if (_hamt_dump_format(writer, "BitmapNode(size=%zd count=%zd ",
+ Py_SIZE(node), Py_SIZE(node) / 2))
+ {
+ goto error;
+ }
+
+ tmp1 = PyLong_FromUnsignedLong(node->b_bitmap);
+ if (tmp1 == NULL) {
+ goto error;
+ }
+ tmp2 = _PyLong_Format(tmp1, 2);
+ Py_DECREF(tmp1);
+ if (tmp2 == NULL) {
+ goto error;
+ }
+ if (_hamt_dump_format(writer, "bitmap=%S id=%p):\n", tmp2, node)) {
+ Py_DECREF(tmp2);
+ goto error;
+ }
+ Py_DECREF(tmp2);
+
+ for (i = 0; i < Py_SIZE(node); i += 2) {
+ PyObject *key_or_null = node->b_array[i];
+ PyObject *val_or_node = node->b_array[i + 1];
+
+ if (_hamt_dump_ident(writer, level + 2)) {
+ goto error;
+ }
+
+ if (key_or_null == NULL) {
+ if (_hamt_dump_format(writer, "NULL:\n")) {
+ goto error;
+ }
+
+ if (hamt_node_dump((PyHamtNode *)val_or_node,
+ writer, level + 2))
+ {
+ goto error;
+ }
+ }
+ else {
+ if (_hamt_dump_format(writer, "%R: %R", key_or_null,
+ val_or_node))
+ {
+ goto error;
+ }
+ }
+
+ if (_hamt_dump_format(writer, "\n")) {
+ goto error;
+ }
+ }
+
+ return 0;
+error:
+ return -1;
+}
+#endif /* Py_DEBUG */
+
+
+/////////////////////////////////// Collision Node
+
+
+static PyHamtNode *
+hamt_node_collision_new(int32_t hash, Py_ssize_t size)
+{
+ /* Create a new Collision node. */
+
+ PyHamtNode_Collision *node;
+ Py_ssize_t i;
+
+ assert(size >= 4);
+ assert(size % 2 == 0);
+
+ node = PyObject_GC_NewVar(
+ PyHamtNode_Collision, &_PyHamt_CollisionNode_Type, size);
+ if (node == NULL) {
+ return NULL;
+ }
+
+ for (i = 0; i < size; i++) {
+ node->c_array[i] = NULL;
+ }
+
+ Py_SIZE(node) = size;
+ node->c_hash = hash;
+
+ _PyObject_GC_TRACK(node);
+
+ return (PyHamtNode *)node;
+}
+
+static hamt_find_t
+hamt_node_collision_find_index(PyHamtNode_Collision *self, PyObject *key,
+ Py_ssize_t *idx)
+{
+ /* Lookup `key` in the Collision node `self`. Set the index of the
+ found key to 'idx'. */
+
+ Py_ssize_t i;
+ PyObject *el;
+
+ for (i = 0; i < Py_SIZE(self); i += 2) {
+ el = self->c_array[i];
+
+ assert(el != NULL);
+ int cmp = PyObject_RichCompareBool(key, el, Py_EQ);
+ if (cmp < 0) {
+ return F_ERROR;
+ }
+ if (cmp == 1) {
+ *idx = i;
+ return F_FOUND;
+ }
+ }
+
+ return F_NOT_FOUND;
+}
+
+static PyHamtNode *
+hamt_node_collision_assoc(PyHamtNode_Collision *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject *val, int* added_leaf)
+{
+ /* Set a new key to this level (currently a Collision node)
+ of the tree. */
+
+ if (hash == self->c_hash) {
+ /* The hash of the 'key' we are adding matches the hash of
+ other keys in this Collision node. */
+
+ Py_ssize_t key_idx = -1;
+ hamt_find_t found;
+ PyHamtNode_Collision *new_node;
+ Py_ssize_t i;
+
+ /* Let's try to lookup the new 'key', maybe we already have it. */
+ found = hamt_node_collision_find_index(self, key, &key_idx);
+ switch (found) {
+ case F_ERROR:
+ /* Exception. */
+ return NULL;
+
+ case F_NOT_FOUND:
+ /* This is a totally new key. Clone the current node,
+ add a new key/value to the cloned node. */
+
+ new_node = (PyHamtNode_Collision *)hamt_node_collision_new(
+ self->c_hash, Py_SIZE(self) + 2);
+ if (new_node == NULL) {
+ return NULL;
+ }
+
+ for (i = 0; i < Py_SIZE(self); i++) {
+ Py_INCREF(self->c_array[i]);
+ new_node->c_array[i] = self->c_array[i];
+ }
+
+ Py_INCREF(key);
+ new_node->c_array[i] = key;
+ Py_INCREF(val);
+ new_node->c_array[i + 1] = val;
+
+ *added_leaf = 1;
+ return (PyHamtNode *)new_node;
+
+ case F_FOUND:
+ /* There's a key which is equal to the key we are adding. */
+
+ assert(key_idx >= 0);
+ assert(key_idx < Py_SIZE(self));
+ Py_ssize_t val_idx = key_idx + 1;
+
+ if (self->c_array[val_idx] == val) {
+ /* We're setting a key/value pair that's already set. */
+ Py_INCREF(self);
+ return (PyHamtNode *)self;
+ }
+
+ /* We need to replace old value for the key
+ with a new value. Create a new Collision node.*/
+ new_node = (PyHamtNode_Collision *)hamt_node_collision_new(
+ self->c_hash, Py_SIZE(self));
+ if (new_node == NULL) {
+ return NULL;
+ }
+
+ /* Copy all elements of the old node to the new one. */
+ for (i = 0; i < Py_SIZE(self); i++) {
+ Py_INCREF(self->c_array[i]);
+ new_node->c_array[i] = self->c_array[i];
+ }
+
+ /* Replace the old value with the new value for the our key. */
+ Py_DECREF(new_node->c_array[val_idx]);
+ Py_INCREF(val);
+ new_node->c_array[val_idx] = val;
+
+ return (PyHamtNode *)new_node;
+
+ default:
+ Py_UNREACHABLE();
+ }
+ }
+ else {
+ /* The hash of the new key is different from the hash that
+ all keys of this Collision node have.
+
+ Create a Bitmap node inplace with two children:
+ key/value pair that we're adding, and the Collision node
+ we're replacing on this tree level.
+ */
+
+ PyHamtNode_Bitmap *new_node;
+ PyHamtNode *assoc_res;
+
+ new_node = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(2);
+ if (new_node == NULL) {
+ return NULL;
+ }
+ new_node->b_bitmap = hamt_bitpos(self->c_hash, shift);
+ Py_INCREF(self);
+ new_node->b_array[1] = (PyObject*) self;
+
+ assoc_res = hamt_node_bitmap_assoc(
+ new_node, shift, hash, key, val, added_leaf);
+ Py_DECREF(new_node);
+ return assoc_res;
+ }
+}
+
+static inline Py_ssize_t
+hamt_node_collision_count(PyHamtNode_Collision *node)
+{
+ return Py_SIZE(node) / 2;
+}
+
+static hamt_without_t
+hamt_node_collision_without(PyHamtNode_Collision *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key,
+ PyHamtNode **new_node)
+{
+ if (hash != self->c_hash) {
+ return W_NOT_FOUND;
+ }
+
+ Py_ssize_t key_idx = -1;
+ hamt_find_t found = hamt_node_collision_find_index(self, key, &key_idx);
+
+ switch (found) {
+ case F_ERROR:
+ return W_ERROR;
+
+ case F_NOT_FOUND:
+ return W_NOT_FOUND;
+
+ case F_FOUND:
+ assert(key_idx >= 0);
+ assert(key_idx < Py_SIZE(self));
+
+ Py_ssize_t new_count = hamt_node_collision_count(self) - 1;
+
+ if (new_count == 0) {
+ /* The node has only one key/value pair and it's for the
+ key we're trying to delete. So a new node will be empty
+ after the removal.
+ */
+ return W_EMPTY;
+ }
+
+ if (new_count == 1) {
+ /* The node has two keys, and after deletion the
+ new Collision node would have one. Collision nodes
+ with one key shouldn't exist, co convert it to a
+ Bitmap node.
+ */
+ PyHamtNode_Bitmap *node = (PyHamtNode_Bitmap *)
+ hamt_node_bitmap_new(2);
+ if (node == NULL) {
+ return W_ERROR;
+ }
+
+ if (key_idx == 0) {
+ Py_INCREF(self->c_array[2]);
+ node->b_array[0] = self->c_array[2];
+ Py_INCREF(self->c_array[3]);
+ node->b_array[1] = self->c_array[3];
+ }
+ else {
+ assert(key_idx == 2);
+ Py_INCREF(self->c_array[0]);
+ node->b_array[0] = self->c_array[0];
+ Py_INCREF(self->c_array[1]);
+ node->b_array[1] = self->c_array[1];
+ }
+
+ node->b_bitmap = hamt_bitpos(hash, shift);
+
+ *new_node = (PyHamtNode *)node;
+ return W_NEWNODE;
+ }
+
+ /* Allocate a new Collision node with capacity for one
+ less key/value pair */
+ PyHamtNode_Collision *new = (PyHamtNode_Collision *)
+ hamt_node_collision_new(
+ self->c_hash, Py_SIZE(self) - 2);
+
+ /* Copy all other keys from `self` to `new` */
+ Py_ssize_t i;
+ for (i = 0; i < key_idx; i++) {
+ Py_INCREF(self->c_array[i]);
+ new->c_array[i] = self->c_array[i];
+ }
+ for (i = key_idx + 2; i < Py_SIZE(self); i++) {
+ Py_INCREF(self->c_array[i]);
+ new->c_array[i - 2] = self->c_array[i];
+ }
+
+ *new_node = (PyHamtNode*)new;
+ return W_NEWNODE;
+
+ default:
+ Py_UNREACHABLE();
+ }
+}
+
+static hamt_find_t
+hamt_node_collision_find(PyHamtNode_Collision *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject **val)
+{
+ /* Lookup `key` in the Collision node `self`. Set the value
+ for the found key to 'val'. */
+
+ Py_ssize_t idx = -1;
+ hamt_find_t res;
+
+ res = hamt_node_collision_find_index(self, key, &idx);
+ if (res == F_ERROR || res == F_NOT_FOUND) {
+ return res;
+ }
+
+ assert(idx >= 0);
+ assert(idx + 1 < Py_SIZE(self));
+
+ *val = self->c_array[idx + 1];
+ assert(*val != NULL);
+
+ return F_FOUND;
+}
+
+
+static int
+hamt_node_collision_traverse(PyHamtNode_Collision *self,
+ visitproc visit, void *arg)
+{
+ /* Collision's tp_traverse */
+
+ Py_ssize_t i;
+
+ for (i = Py_SIZE(self); --i >= 0; ) {
+ Py_VISIT(self->c_array[i]);
+ }
+
+ return 0;
+}
+
+static void
+hamt_node_collision_dealloc(PyHamtNode_Collision *self)
+{
+ /* Collision's tp_dealloc */
+
+ Py_ssize_t len = Py_SIZE(self);
+
+ PyObject_GC_UnTrack(self);
+ Py_TRASHCAN_SAFE_BEGIN(self)
+
+ if (len > 0) {
+
+ while (--len >= 0) {
+ Py_XDECREF(self->c_array[len]);
+ }
+ }
+
+ Py_TYPE(self)->tp_free((PyObject *)self);
+ Py_TRASHCAN_SAFE_END(self)
+}
+
+#ifdef Py_DEBUG
+static int
+hamt_node_collision_dump(PyHamtNode_Collision *node,
+ _PyUnicodeWriter *writer, int level)
+{
+ /* Debug build: __dump__() method implementation for Collision nodes. */
+
+ Py_ssize_t i;
+
+ if (_hamt_dump_ident(writer, level + 1)) {
+ goto error;
+ }
+
+ if (_hamt_dump_format(writer, "CollisionNode(size=%zd id=%p):\n",
+ Py_SIZE(node), node))
+ {
+ goto error;
+ }
+
+ for (i = 0; i < Py_SIZE(node); i += 2) {
+ PyObject *key = node->c_array[i];
+ PyObject *val = node->c_array[i + 1];
+
+ if (_hamt_dump_ident(writer, level + 2)) {
+ goto error;
+ }
+
+ if (_hamt_dump_format(writer, "%R: %R\n", key, val)) {
+ goto error;
+ }
+ }
+
+ return 0;
+error:
+ return -1;
+}
+#endif /* Py_DEBUG */
+
+
+/////////////////////////////////// Array Node
+
+
+static PyHamtNode *
+hamt_node_array_new(Py_ssize_t count)
+{
+ Py_ssize_t i;
+
+ PyHamtNode_Array *node = PyObject_GC_New(
+ PyHamtNode_Array, &_PyHamt_ArrayNode_Type);
+ if (node == NULL) {
+ return NULL;
+ }
+
+ for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ node->a_array[i] = NULL;
+ }
+
+ node->a_count = count;
+
+ _PyObject_GC_TRACK(node);
+ return (PyHamtNode *)node;
+}
+
+static PyHamtNode_Array *
+hamt_node_array_clone(PyHamtNode_Array *node)
+{
+ PyHamtNode_Array *clone;
+ Py_ssize_t i;
+
+ VALIDATE_ARRAY_NODE(node)
+
+ /* Create a new Array node. */
+ clone = (PyHamtNode_Array *)hamt_node_array_new(node->a_count);
+ if (clone == NULL) {
+ return NULL;
+ }
+
+ /* Copy all elements from the current Array node to the new one. */
+ for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ Py_XINCREF(node->a_array[i]);
+ clone->a_array[i] = node->a_array[i];
+ }
+
+ VALIDATE_ARRAY_NODE(clone)
+ return clone;
+}
+
+static PyHamtNode *
+hamt_node_array_assoc(PyHamtNode_Array *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject *val, int* added_leaf)
+{
+ /* Set a new key to this level (currently a Collision node)
+ of the tree.
+
+ Array nodes don't store values, they can only point to
+ other nodes. They are simple arrays of 32 BaseNode pointers/
+ */
+
+ uint32_t idx = hamt_mask(hash, shift);
+ PyHamtNode *node = self->a_array[idx];
+ PyHamtNode *child_node;
+ PyHamtNode_Array *new_node;
+ Py_ssize_t i;
+
+ if (node == NULL) {
+ /* There's no child node for the given hash. Create a new
+ Bitmap node for this key. */
+
+ PyHamtNode_Bitmap *empty = NULL;
+
+ /* Get an empty Bitmap node to work with. */
+ empty = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(0);
+ if (empty == NULL) {
+ return NULL;
+ }
+
+ /* Set key/val to the newly created empty Bitmap, thus
+ creating a new Bitmap node with our key/value pair. */
+ child_node = hamt_node_bitmap_assoc(
+ empty,
+ shift + 5, hash, key, val, added_leaf);
+ Py_DECREF(empty);
+ if (child_node == NULL) {
+ return NULL;
+ }
+
+ /* Create a new Array node. */
+ new_node = (PyHamtNode_Array *)hamt_node_array_new(self->a_count + 1);
+ if (new_node == NULL) {
+ Py_DECREF(child_node);
+ return NULL;
+ }
+
+ /* Copy all elements from the current Array node to the
+ new one. */
+ for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ Py_XINCREF(self->a_array[i]);
+ new_node->a_array[i] = self->a_array[i];
+ }
+
+ assert(new_node->a_array[idx] == NULL);
+ new_node->a_array[idx] = child_node; /* borrow */
+ VALIDATE_ARRAY_NODE(new_node)
+ }
+ else {
+ /* There's a child node for the given hash.
+ Set the key to it./ */
+ child_node = hamt_node_assoc(
+ node, shift + 5, hash, key, val, added_leaf);
+ if (child_node == (PyHamtNode *)self) {
+ Py_DECREF(child_node);
+ return (PyHamtNode *)self;
+ }
+
+ new_node = hamt_node_array_clone(self);
+ if (new_node == NULL) {
+ Py_DECREF(child_node);
+ return NULL;
+ }
+
+ Py_SETREF(new_node->a_array[idx], child_node); /* borrow */
+ VALIDATE_ARRAY_NODE(new_node)
+ }
+
+ return (PyHamtNode *)new_node;
+}
+
+static hamt_without_t
+hamt_node_array_without(PyHamtNode_Array *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key,
+ PyHamtNode **new_node)
+{
+ uint32_t idx = hamt_mask(hash, shift);
+ PyHamtNode *node = self->a_array[idx];
+
+ if (node == NULL) {
+ return W_NOT_FOUND;
+ }
+
+ PyHamtNode *sub_node = NULL;
+ hamt_without_t res = hamt_node_without(
+ (PyHamtNode *)node,
+ shift + 5, hash, key, &sub_node);
+
+ switch (res) {
+ case W_NOT_FOUND:
+ case W_ERROR:
+ assert(sub_node == NULL);
+ return res;
+
+ case W_NEWNODE: {
+ /* We need to replace a node at the `idx` index.
+ Clone this node and replace.
+ */
+ assert(sub_node != NULL);
+
+ PyHamtNode_Array *clone = hamt_node_array_clone(self);
+ if (clone == NULL) {
+ Py_DECREF(sub_node);
+ return W_ERROR;
+ }
+
+ Py_SETREF(clone->a_array[idx], sub_node); /* borrow */
+ *new_node = (PyHamtNode*)clone; /* borrow */
+ return W_NEWNODE;
+ }
+
+ case W_EMPTY: {
+ assert(sub_node == NULL);
+ /* We need to remove a node at the `idx` index.
+ Calculate the size of the replacement Array node.
+ */
+ Py_ssize_t new_count = self->a_count - 1;
+
+ if (new_count == 0) {
+ return W_EMPTY;
+ }
+
+ if (new_count >= 16) {
+ /* We convert Bitmap nodes to Array nodes, when a
+ Bitmap node needs to store more than 15 key/value
+ pairs. So we will create a new Array node if we
+ the number of key/values after deletion is still
+ greater than 15.
+ */
+
+ PyHamtNode_Array *new = hamt_node_array_clone(self);
+ if (new == NULL) {
+ return W_ERROR;
+ }
+ new->a_count = new_count;
+ Py_CLEAR(new->a_array[idx]);
+
+ *new_node = (PyHamtNode*)new; /* borrow */
+ return W_NEWNODE;
+ }
+
+ /* New Array node would have less than 16 key/value
+ pairs. We need to create a replacement Bitmap node. */
+
+ Py_ssize_t bitmap_size = new_count * 2;
+ uint32_t bitmap = 0;
+
+ PyHamtNode_Bitmap *new = (PyHamtNode_Bitmap *)
+ hamt_node_bitmap_new(bitmap_size);
+ if (new == NULL) {
+ return W_ERROR;
+ }
+
+ Py_ssize_t new_i = 0;
+ for (uint32_t i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ if (i == idx) {
+ /* Skip the node we are deleting. */
+ continue;
+ }
+
+ PyHamtNode *node = self->a_array[i];
+ if (node == NULL) {
+ /* Skip any missing nodes. */
+ continue;
+ }
+
+ bitmap |= 1 << i;
+
+ if (IS_BITMAP_NODE(node)) {
+ PyHamtNode_Bitmap *child = (PyHamtNode_Bitmap *)node;
+
+ if (hamt_node_bitmap_count(child) == 1 &&
+ child->b_array[0] != NULL)
+ {
+ /* node is a Bitmap with one key/value pair, just
+ merge it into the new Bitmap node we're building.
+
+ Note that we don't inline Bitmap nodes that
+ have a NULL key -- those nodes point to another
+ tree level, and we cannot simply move tree levels
+ up or down.
+ */
+ PyObject *key = child->b_array[0];
+ PyObject *val = child->b_array[1];
+
+ Py_INCREF(key);
+ new->b_array[new_i] = key;
+ Py_INCREF(val);
+ new->b_array[new_i + 1] = val;
+ }
+ else {
+ new->b_array[new_i] = NULL;
+ Py_INCREF(node);
+ new->b_array[new_i + 1] = (PyObject*)node;
+ }
+ }
+ else {
+
+#ifdef Py_DEBUG
+ if (IS_COLLISION_NODE(node)) {
+ Py_ssize_t child_count = hamt_node_collision_count(
+ (PyHamtNode_Collision*)node);
+ assert(child_count > 1);
+ }
+ else if (IS_ARRAY_NODE(node)) {
+ assert(((PyHamtNode_Array*)node)->a_count >= 16);
+ }
+#endif
+
+ /* Just copy the node into our new Bitmap */
+ new->b_array[new_i] = NULL;
+ Py_INCREF(node);
+ new->b_array[new_i + 1] = (PyObject*)node;
+ }
+
+ new_i += 2;
+ }
+
+ new->b_bitmap = bitmap;
+ *new_node = (PyHamtNode*)new; /* borrow */
+ return W_NEWNODE;
+ }
+
+ default:
+ Py_UNREACHABLE();
+ }
+}
+
+static hamt_find_t
+hamt_node_array_find(PyHamtNode_Array *self,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject **val)
+{
+ /* Lookup `key` in the Array node `self`. Set the value
+ for the found key to 'val'. */
+
+ uint32_t idx = hamt_mask(hash, shift);
+ PyHamtNode *node;
+
+ node = self->a_array[idx];
+ if (node == NULL) {
+ return F_NOT_FOUND;
+ }
+
+ /* Dispatch to the generic hamt_node_find */
+ return hamt_node_find(node, shift + 5, hash, key, val);
+}
+
+static int
+hamt_node_array_traverse(PyHamtNode_Array *self,
+ visitproc visit, void *arg)
+{
+ /* Array's tp_traverse */
+
+ Py_ssize_t i;
+
+ for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ Py_VISIT(self->a_array[i]);
+ }
+
+ return 0;
+}
+
+static void
+hamt_node_array_dealloc(PyHamtNode_Array *self)
+{
+ /* Array's tp_dealloc */
+
+ Py_ssize_t i;
+
+ PyObject_GC_UnTrack(self);
+ Py_TRASHCAN_SAFE_BEGIN(self)
+
+ for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ Py_XDECREF(self->a_array[i]);
+ }
+
+ Py_TYPE(self)->tp_free((PyObject *)self);
+ Py_TRASHCAN_SAFE_END(self)
+}
+
+#ifdef Py_DEBUG
+static int
+hamt_node_array_dump(PyHamtNode_Array *node,
+ _PyUnicodeWriter *writer, int level)
+{
+ /* Debug build: __dump__() method implementation for Array nodes. */
+
+ Py_ssize_t i;
+
+ if (_hamt_dump_ident(writer, level + 1)) {
+ goto error;
+ }
+
+ if (_hamt_dump_format(writer, "ArrayNode(id=%p):\n", node)) {
+ goto error;
+ }
+
+ for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ if (node->a_array[i] == NULL) {
+ continue;
+ }
+
+ if (_hamt_dump_ident(writer, level + 2)) {
+ goto error;
+ }
+
+ if (_hamt_dump_format(writer, "%d::\n", i)) {
+ goto error;
+ }
+
+ if (hamt_node_dump(node->a_array[i], writer, level + 1)) {
+ goto error;
+ }
+
+ if (_hamt_dump_format(writer, "\n")) {
+ goto error;
+ }
+ }
+
+ return 0;
+error:
+ return -1;
+}
+#endif /* Py_DEBUG */
+
+
+/////////////////////////////////// Node Dispatch
+
+
+static PyHamtNode *
+hamt_node_assoc(PyHamtNode *node,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject *val, int* added_leaf)
+{
+ /* Set key/value to the 'node' starting with the given shift/hash.
+ Return a new node, or the same node if key/value already
+ set.
+
+ added_leaf will be set to 1 if key/value wasn't in the
+ tree before.
+
+ This method automatically dispatches to the suitable
+ hamt_node_{nodetype}_assoc method.
+ */
+
+ if (IS_BITMAP_NODE(node)) {
+ return hamt_node_bitmap_assoc(
+ (PyHamtNode_Bitmap *)node,
+ shift, hash, key, val, added_leaf);
+ }
+ else if (IS_ARRAY_NODE(node)) {
+ return hamt_node_array_assoc(
+ (PyHamtNode_Array *)node,
+ shift, hash, key, val, added_leaf);
+ }
+ else {
+ assert(IS_COLLISION_NODE(node));
+ return hamt_node_collision_assoc(
+ (PyHamtNode_Collision *)node,
+ shift, hash, key, val, added_leaf);
+ }
+}
+
+static hamt_without_t
+hamt_node_without(PyHamtNode *node,
+ uint32_t shift, int32_t hash,
+ PyObject *key,
+ PyHamtNode **new_node)
+{
+ if (IS_BITMAP_NODE(node)) {
+ return hamt_node_bitmap_without(
+ (PyHamtNode_Bitmap *)node,
+ shift, hash, key,
+ new_node);
+ }
+ else if (IS_ARRAY_NODE(node)) {
+ return hamt_node_array_without(
+ (PyHamtNode_Array *)node,
+ shift, hash, key,
+ new_node);
+ }
+ else {
+ assert(IS_COLLISION_NODE(node));
+ return hamt_node_collision_without(
+ (PyHamtNode_Collision *)node,
+ shift, hash, key,
+ new_node);
+ }
+}
+
+static hamt_find_t
+hamt_node_find(PyHamtNode *node,
+ uint32_t shift, int32_t hash,
+ PyObject *key, PyObject **val)
+{
+ /* Find the key in the node starting with the given shift/hash.
+
+ If a value is found, the result will be set to F_FOUND, and
+ *val will point to the found value object.
+
+ If a value wasn't found, the result will be set to F_NOT_FOUND.
+
+ If an exception occurs during the call, the result will be F_ERROR.
+
+ This method automatically dispatches to the suitable
+ hamt_node_{nodetype}_find method.
+ */
+
+ if (IS_BITMAP_NODE(node)) {
+ return hamt_node_bitmap_find(
+ (PyHamtNode_Bitmap *)node,
+ shift, hash, key, val);
+
+ }
+ else if (IS_ARRAY_NODE(node)) {
+ return hamt_node_array_find(
+ (PyHamtNode_Array *)node,
+ shift, hash, key, val);
+ }
+ else {
+ assert(IS_COLLISION_NODE(node));
+ return hamt_node_collision_find(
+ (PyHamtNode_Collision *)node,
+ shift, hash, key, val);
+ }
+}
+
+#ifdef Py_DEBUG
+static int
+hamt_node_dump(PyHamtNode *node,
+ _PyUnicodeWriter *writer, int level)
+{
+ /* Debug build: __dump__() method implementation for a node.
+
+ This method automatically dispatches to the suitable
+ hamt_node_{nodetype})_dump method.
+ */
+
+ if (IS_BITMAP_NODE(node)) {
+ return hamt_node_bitmap_dump(
+ (PyHamtNode_Bitmap *)node, writer, level);
+ }
+ else if (IS_ARRAY_NODE(node)) {
+ return hamt_node_array_dump(
+ (PyHamtNode_Array *)node, writer, level);
+ }
+ else {
+ assert(IS_COLLISION_NODE(node));
+ return hamt_node_collision_dump(
+ (PyHamtNode_Collision *)node, writer, level);
+ }
+}
+#endif /* Py_DEBUG */
+
+
+/////////////////////////////////// Iterators: Machinery
+
+
+static hamt_iter_t
+hamt_iterator_next(PyHamtIteratorState *iter, PyObject **key, PyObject **val);
+
+
+static void
+hamt_iterator_init(PyHamtIteratorState *iter, PyHamtNode *root)
+{
+ for (uint32_t i = 0; i < _Py_HAMT_MAX_TREE_DEPTH; i++) {
+ iter->i_nodes[i] = NULL;
+ iter->i_pos[i] = 0;
+ }
+
+ iter->i_level = 0;
+
+ /* Note: we don't incref/decref nodes in i_nodes. */
+ iter->i_nodes[0] = root;
+}
+
+static hamt_iter_t
+hamt_iterator_bitmap_next(PyHamtIteratorState *iter,
+ PyObject **key, PyObject **val)
+{
+ int8_t level = iter->i_level;
+
+ PyHamtNode_Bitmap *node = (PyHamtNode_Bitmap *)(iter->i_nodes[level]);
+ Py_ssize_t pos = iter->i_pos[level];
+
+ if (pos + 1 >= Py_SIZE(node)) {
+#ifdef Py_DEBUG
+ assert(iter->i_level >= 0);
+ iter->i_nodes[iter->i_level] = NULL;
+#endif
+ iter->i_level--;
+ return hamt_iterator_next(iter, key, val);
+ }
+
+ if (node->b_array[pos] == NULL) {
+ iter->i_pos[level] = pos + 2;
+
+ int8_t next_level = level + 1;
+ assert(next_level < _Py_HAMT_MAX_TREE_DEPTH);
+ iter->i_level = next_level;
+ iter->i_pos[next_level] = 0;
+ iter->i_nodes[next_level] = (PyHamtNode *)
+ node->b_array[pos + 1];
+
+ return hamt_iterator_next(iter, key, val);
+ }
+
+ *key = node->b_array[pos];
+ *val = node->b_array[pos + 1];
+ iter->i_pos[level] = pos + 2;
+ return I_ITEM;
+}
+
+static hamt_iter_t
+hamt_iterator_collision_next(PyHamtIteratorState *iter,
+ PyObject **key, PyObject **val)
+{
+ int8_t level = iter->i_level;
+
+ PyHamtNode_Collision *node = (PyHamtNode_Collision *)(iter->i_nodes[level]);
+ Py_ssize_t pos = iter->i_pos[level];
+
+ if (pos + 1 >= Py_SIZE(node)) {
+#ifdef Py_DEBUG
+ assert(iter->i_level >= 0);
+ iter->i_nodes[iter->i_level] = NULL;
+#endif
+ iter->i_level--;
+ return hamt_iterator_next(iter, key, val);
+ }
+
+ *key = node->c_array[pos];
+ *val = node->c_array[pos + 1];
+ iter->i_pos[level] = pos + 2;
+ return I_ITEM;
+}
+
+static hamt_iter_t
+hamt_iterator_array_next(PyHamtIteratorState *iter,
+ PyObject **key, PyObject **val)
+{
+ int8_t level = iter->i_level;
+
+ PyHamtNode_Array *node = (PyHamtNode_Array *)(iter->i_nodes[level]);
+ Py_ssize_t pos = iter->i_pos[level];
+
+ if (pos >= HAMT_ARRAY_NODE_SIZE) {
+#ifdef Py_DEBUG
+ assert(iter->i_level >= 0);
+ iter->i_nodes[iter->i_level] = NULL;
+#endif
+ iter->i_level--;
+ return hamt_iterator_next(iter, key, val);
+ }
+
+ for (Py_ssize_t i = pos; i < HAMT_ARRAY_NODE_SIZE; i++) {
+ if (node->a_array[i] != NULL) {
+ iter->i_pos[level] = i + 1;
+
+ int8_t next_level = level + 1;
+ assert(next_level < _Py_HAMT_MAX_TREE_DEPTH);
+ iter->i_pos[next_level] = 0;
+ iter->i_nodes[next_level] = node->a_array[i];
+ iter->i_level = next_level;
+
+ return hamt_iterator_next(iter, key, val);
+ }
+ }
+
+#ifdef Py_DEBUG
+ assert(iter->i_level >= 0);
+ iter->i_nodes[iter->i_level] = NULL;
+#endif
+
+ iter->i_level--;
+ return hamt_iterator_next(iter, key, val);
+}
+
+static hamt_iter_t
+hamt_iterator_next(PyHamtIteratorState *iter, PyObject **key, PyObject **val)
+{
+ if (iter->i_level < 0) {
+ return I_END;
+ }
+
+ assert(iter->i_level < _Py_HAMT_MAX_TREE_DEPTH);
+
+ PyHamtNode *current = iter->i_nodes[iter->i_level];
+
+ if (IS_BITMAP_NODE(current)) {
+ return hamt_iterator_bitmap_next(iter, key, val);
+ }
+ else if (IS_ARRAY_NODE(current)) {
+ return hamt_iterator_array_next(iter, key, val);
+ }
+ else {
+ assert(IS_COLLISION_NODE(current));
+ return hamt_iterator_collision_next(iter, key, val);
+ }
+}
+
+
+/////////////////////////////////// HAMT high-level functions
+
+
+PyHamtObject *
+_PyHamt_Assoc(PyHamtObject *o, PyObject *key, PyObject *val)
+{
+ int32_t key_hash;
+ int added_leaf = 0;
+ PyHamtNode *new_root;
+ PyHamtObject *new_o;
+
+ key_hash = hamt_hash(key);
+ if (key_hash == -1) {
+ return NULL;
+ }
+
+ new_root = hamt_node_assoc(
+ (PyHamtNode *)(o->h_root),
+ 0, key_hash, key, val, &added_leaf);
+ if (new_root == NULL) {
+ return NULL;
+ }
+
+ if (new_root == o->h_root) {
+ Py_DECREF(new_root);
+ Py_INCREF(o);
+ return o;
+ }
+
+ new_o = hamt_alloc();
+ if (new_o == NULL) {
+ Py_DECREF(new_root);
+ return NULL;
+ }
+
+ new_o->h_root = new_root; /* borrow */
+ new_o->h_count = added_leaf ? o->h_count + 1 : o->h_count;
+
+ return new_o;
+}
+
+PyHamtObject *
+_PyHamt_Without(PyHamtObject *o, PyObject *key)
+{
+ int32_t key_hash = hamt_hash(key);
+ if (key_hash == -1) {
+ return NULL;
+ }
+
+ PyHamtNode *new_root;
+
+ hamt_without_t res = hamt_node_without(
+ (PyHamtNode *)(o->h_root),
+ 0, key_hash, key,
+ &new_root);
+
+ switch (res) {
+ case W_ERROR:
+ return NULL;
+ case W_EMPTY:
+ return _PyHamt_New();
+ case W_NOT_FOUND:
+ Py_INCREF(o);
+ return o;
+ case W_NEWNODE: {
+ PyHamtObject *new_o = hamt_alloc();
+ if (new_o == NULL) {
+ Py_DECREF(new_root);
+ return NULL;
+ }
+
+ new_o->h_root = new_root; /* borrow */
+ new_o->h_count = o->h_count - 1;
+ assert(new_o->h_count >= 0);
+ return new_o;
+ }
+ default:
+ Py_UNREACHABLE();
+ }
+}
+
+static hamt_find_t
+hamt_find(PyHamtObject *o, PyObject *key, PyObject **val)
+{
+ if (o->h_count == 0) {
+ return F_NOT_FOUND;
+ }
+
+ int32_t key_hash = hamt_hash(key);
+ if (key_hash == -1) {
+ return F_ERROR;
+ }
+
+ return hamt_node_find(o->h_root, 0, key_hash, key, val);
+}
+
+
+int
+_PyHamt_Find(PyHamtObject *o, PyObject *key, PyObject **val)
+{
+ hamt_find_t res = hamt_find(o, key, val);
+ switch (res) {
+ case F_ERROR:
+ return -1;
+ case F_NOT_FOUND:
+ return 0;
+ case F_FOUND:
+ return 1;
+ default:
+ Py_UNREACHABLE();
+ }
+}
+
+
+int
+_PyHamt_Eq(PyHamtObject *v, PyHamtObject *w)
+{
+ if (v == w) {
+ return 1;
+ }
+
+ if (v->h_count != w->h_count) {
+ return 0;
+ }
+
+ PyHamtIteratorState iter;
+ hamt_iter_t iter_res;
+ hamt_find_t find_res;
+ PyObject *v_key;
+ PyObject *v_val;
+ PyObject *w_val;
+
+ hamt_iterator_init(&iter, v->h_root);
+
+ do {
+ iter_res = hamt_iterator_next(&iter, &v_key, &v_val);
+ if (iter_res == I_ITEM) {
+ find_res = hamt_find(w, v_key, &w_val);
+ switch (find_res) {
+ case F_ERROR:
+ return -1;
+
+ case F_NOT_FOUND:
+ return 0;
+
+ case F_FOUND: {
+ int cmp = PyObject_RichCompareBool(v_val, w_val, Py_EQ);
+ if (cmp < 0) {
+ return -1;
+ }
+ if (cmp == 0) {
+ return 0;
+ }
+ }
+ }
+ }
+ } while (iter_res != I_END);
+
+ return 1;
+}
+
+Py_ssize_t
+_PyHamt_Len(PyHamtObject *o)
+{
+ return o->h_count;
+}
+
+static PyHamtObject *
+hamt_alloc(void)
+{
+ PyHamtObject *o;
+ o = PyObject_GC_New(PyHamtObject, &_PyHamt_Type);
+ if (o == NULL) {
+ return NULL;
+ }
+ o->h_weakreflist = NULL;
+ PyObject_GC_Track(o);
+ return o;
+}
+
+PyHamtObject *
+_PyHamt_New(void)
+{
+ if (_empty_hamt != NULL) {
+ /* HAMT is an immutable object so we can easily cache an
+ empty instance. */
+ Py_INCREF(_empty_hamt);
+ return _empty_hamt;
+ }
+
+ PyHamtObject *o = hamt_alloc();
+ if (o == NULL) {
+ return NULL;
+ }
+
+ o->h_root = hamt_node_bitmap_new(0);
+ if (o->h_root == NULL) {
+ Py_DECREF(o);
+ return NULL;
+ }
+
+ o->h_count = 0;
+
+ if (_empty_hamt == NULL) {
+ Py_INCREF(o);
+ _empty_hamt = o;
+ }
+
+ return o;
+}
+
+#ifdef Py_DEBUG
+static PyObject *
+hamt_dump(PyHamtObject *self)
+{
+ _PyUnicodeWriter writer;
+
+ _PyUnicodeWriter_Init(&writer);
+
+ if (_hamt_dump_format(&writer, "HAMT(len=%zd):\n", self->h_count)) {
+ goto error;
+ }
+
+ if (hamt_node_dump(self->h_root, &writer, 0)) {
+ goto error;
+ }
+
+ return _PyUnicodeWriter_Finish(&writer);
+
+error:
+ _PyUnicodeWriter_Dealloc(&writer);
+ return NULL;
+}
+#endif /* Py_DEBUG */
+
+
+/////////////////////////////////// Iterators: Shared Iterator Implementation
+
+
+static int
+hamt_baseiter_tp_clear(PyHamtIterator *it)
+{
+ Py_CLEAR(it->hi_obj);
+ return 0;
+}
+
+static void
+hamt_baseiter_tp_dealloc(PyHamtIterator *it)
+{
+ PyObject_GC_UnTrack(it);
+ (void)hamt_baseiter_tp_clear(it);
+ PyObject_GC_Del(it);
+}
+
+static int
+hamt_baseiter_tp_traverse(PyHamtIterator *it, visitproc visit, void *arg)
+{
+ Py_VISIT(it->hi_obj);
+ return 0;
+}
+
+static PyObject *
+hamt_baseiter_tp_iternext(PyHamtIterator *it)
+{
+ PyObject *key;
+ PyObject *val;
+ hamt_iter_t res = hamt_iterator_next(&it->hi_iter, &key, &val);
+
+ switch (res) {
+ case I_END:
+ PyErr_SetNone(PyExc_StopIteration);
+ return NULL;
+
+ case I_ITEM: {
+ return (*(it->hi_yield))(key, val);
+ }
+
+ default: {
+ Py_UNREACHABLE();
+ }
+ }
+}
+
+static Py_ssize_t
+hamt_baseiter_tp_len(PyHamtIterator *it)
+{
+ return it->hi_obj->h_count;
+}
+
+static PyMappingMethods PyHamtIterator_as_mapping = {
+ (lenfunc)hamt_baseiter_tp_len,
+};
+
+static PyObject *
+hamt_baseiter_new(PyTypeObject *type, binaryfunc yield, PyHamtObject *o)
+{
+ PyHamtIterator *it = PyObject_GC_New(PyHamtIterator, type);
+ if (it == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(o);
+ it->hi_obj = o;
+ it->hi_yield = yield;
+
+ hamt_iterator_init(&it->hi_iter, o->h_root);
+
+ return (PyObject*)it;
+}
+
+#define ITERATOR_TYPE_SHARED_SLOTS \
+ .tp_basicsize = sizeof(PyHamtIterator), \
+ .tp_itemsize = 0, \
+ .tp_as_mapping = &PyHamtIterator_as_mapping, \
+ .tp_dealloc = (destructor)hamt_baseiter_tp_dealloc, \
+ .tp_getattro = PyObject_GenericGetAttr, \
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \
+ .tp_traverse = (traverseproc)hamt_baseiter_tp_traverse, \
+ .tp_clear = (inquiry)hamt_baseiter_tp_clear, \
+ .tp_iter = PyObject_SelfIter, \
+ .tp_iternext = (iternextfunc)hamt_baseiter_tp_iternext,
+
+
+/////////////////////////////////// _PyHamtItems_Type
+
+
+PyTypeObject _PyHamtItems_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "items",
+ ITERATOR_TYPE_SHARED_SLOTS
+};
+
+static PyObject *
+hamt_iter_yield_items(PyObject *key, PyObject *val)
+{
+ return PyTuple_Pack(2, key, val);
+}
+
+PyObject *
+_PyHamt_NewIterItems(PyHamtObject *o)
+{
+ return hamt_baseiter_new(
+ &_PyHamtItems_Type, hamt_iter_yield_items, o);
+}
+
+
+/////////////////////////////////// _PyHamtKeys_Type
+
+
+PyTypeObject _PyHamtKeys_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "keys",
+ ITERATOR_TYPE_SHARED_SLOTS
+};
+
+static PyObject *
+hamt_iter_yield_keys(PyObject *key, PyObject *val)
+{
+ Py_INCREF(key);
+ return key;
+}
+
+PyObject *
+_PyHamt_NewIterKeys(PyHamtObject *o)
+{
+ return hamt_baseiter_new(
+ &_PyHamtKeys_Type, hamt_iter_yield_keys, o);
+}
+
+
+/////////////////////////////////// _PyHamtValues_Type
+
+
+PyTypeObject _PyHamtValues_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "values",
+ ITERATOR_TYPE_SHARED_SLOTS
+};
+
+static PyObject *
+hamt_iter_yield_values(PyObject *key, PyObject *val)
+{
+ Py_INCREF(val);
+ return val;
+}
+
+PyObject *
+_PyHamt_NewIterValues(PyHamtObject *o)
+{
+ return hamt_baseiter_new(
+ &_PyHamtValues_Type, hamt_iter_yield_values, o);
+}
+
+
+/////////////////////////////////// _PyHamt_Type
+
+
+#ifdef Py_DEBUG
+static PyObject *
+hamt_dump(PyHamtObject *self);
+#endif
+
+
+static PyObject *
+hamt_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ return (PyObject*)_PyHamt_New();
+}
+
+static int
+hamt_tp_clear(PyHamtObject *self)
+{
+ Py_CLEAR(self->h_root);
+ return 0;
+}
+
+
+static int
+hamt_tp_traverse(PyHamtObject *self, visitproc visit, void *arg)
+{
+ Py_VISIT(self->h_root);
+ return 0;
+}
+
+static void
+hamt_tp_dealloc(PyHamtObject *self)
+{
+ PyObject_GC_UnTrack(self);
+ if (self->h_weakreflist != NULL) {
+ PyObject_ClearWeakRefs((PyObject*)self);
+ }
+ (void)hamt_tp_clear(self);
+ Py_TYPE(self)->tp_free(self);
+}
+
+
+static PyObject *
+hamt_tp_richcompare(PyObject *v, PyObject *w, int op)
+{
+ if (!PyHamt_Check(v) || !PyHamt_Check(w) || (op != Py_EQ && op != Py_NE)) {
+ Py_RETURN_NOTIMPLEMENTED;
+ }
+
+ int res = _PyHamt_Eq((PyHamtObject *)v, (PyHamtObject *)w);
+ if (res < 0) {
+ return NULL;
+ }
+
+ if (op == Py_NE) {
+ res = !res;
+ }
+
+ if (res) {
+ Py_RETURN_TRUE;
+ }
+ else {
+ Py_RETURN_FALSE;
+ }
+}
+
+static int
+hamt_tp_contains(PyHamtObject *self, PyObject *key)
+{
+ PyObject *val;
+ return _PyHamt_Find(self, key, &val);
+}
+
+static PyObject *
+hamt_tp_subscript(PyHamtObject *self, PyObject *key)
+{
+ PyObject *val;
+ hamt_find_t res = hamt_find(self, key, &val);
+ switch (res) {
+ case F_ERROR:
+ return NULL;
+ case F_FOUND:
+ Py_INCREF(val);
+ return val;
+ case F_NOT_FOUND:
+ PyErr_SetObject(PyExc_KeyError, key);
+ return NULL;
+ default:
+ Py_UNREACHABLE();
+ }
+}
+
+static Py_ssize_t
+hamt_tp_len(PyHamtObject *self)
+{
+ return _PyHamt_Len(self);
+}
+
+static PyObject *
+hamt_tp_iter(PyHamtObject *self)
+{
+ return _PyHamt_NewIterKeys(self);
+}
+
+static PyObject *
+hamt_py_set(PyHamtObject *self, PyObject *args)
+{
+ PyObject *key;
+ PyObject *val;
+
+ if (!PyArg_UnpackTuple(args, "set", 2, 2, &key, &val)) {
+ return NULL;
+ }
+
+ return (PyObject *)_PyHamt_Assoc(self, key, val);
+}
+
+static PyObject *
+hamt_py_get(PyHamtObject *self, PyObject *args)
+{
+ PyObject *key;
+ PyObject *def = NULL;
+
+ if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &def)) {
+ return NULL;
+ }
+
+ PyObject *val = NULL;
+ hamt_find_t res = hamt_find(self, key, &val);
+ switch (res) {
+ case F_ERROR:
+ return NULL;
+ case F_FOUND:
+ Py_INCREF(val);
+ return val;
+ case F_NOT_FOUND:
+ if (def == NULL) {
+ Py_RETURN_NONE;
+ }
+ Py_INCREF(def);
+ return def;
+ default:
+ Py_UNREACHABLE();
+ }
+}
+
+static PyObject *
+hamt_py_delete(PyHamtObject *self, PyObject *key)
+{
+ return (PyObject *)_PyHamt_Without(self, key);
+}
+
+static PyObject *
+hamt_py_items(PyHamtObject *self, PyObject *args)
+{
+ return _PyHamt_NewIterItems(self);
+}
+
+static PyObject *
+hamt_py_values(PyHamtObject *self, PyObject *args)
+{
+ return _PyHamt_NewIterValues(self);
+}
+
+static PyObject *
+hamt_py_keys(PyHamtObject *self, PyObject *args)
+{
+ return _PyHamt_NewIterKeys(self);
+}
+
+#ifdef Py_DEBUG
+static PyObject *
+hamt_py_dump(PyHamtObject *self, PyObject *args)
+{
+ return hamt_dump(self);
+}
+#endif
+
+
+static PyMethodDef PyHamt_methods[] = {
+ {"set", (PyCFunction)hamt_py_set, METH_VARARGS, NULL},
+ {"get", (PyCFunction)hamt_py_get, METH_VARARGS, NULL},
+ {"delete", (PyCFunction)hamt_py_delete, METH_O, NULL},
+ {"items", (PyCFunction)hamt_py_items, METH_NOARGS, NULL},
+ {"keys", (PyCFunction)hamt_py_keys, METH_NOARGS, NULL},
+ {"values", (PyCFunction)hamt_py_values, METH_NOARGS, NULL},
+#ifdef Py_DEBUG
+ {"__dump__", (PyCFunction)hamt_py_dump, METH_NOARGS, NULL},
+#endif
+ {NULL, NULL}
+};
+
+static PySequenceMethods PyHamt_as_sequence = {
+ 0, /* sq_length */
+ 0, /* sq_concat */
+ 0, /* sq_repeat */
+ 0, /* sq_item */
+ 0, /* sq_slice */
+ 0, /* sq_ass_item */
+ 0, /* sq_ass_slice */
+ (objobjproc)hamt_tp_contains, /* sq_contains */
+ 0, /* sq_inplace_concat */
+ 0, /* sq_inplace_repeat */
+};
+
+static PyMappingMethods PyHamt_as_mapping = {
+ (lenfunc)hamt_tp_len, /* mp_length */
+ (binaryfunc)hamt_tp_subscript, /* mp_subscript */
+};
+
+PyTypeObject _PyHamt_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "hamt",
+ sizeof(PyHamtObject),
+ .tp_methods = PyHamt_methods,
+ .tp_as_mapping = &PyHamt_as_mapping,
+ .tp_as_sequence = &PyHamt_as_sequence,
+ .tp_iter = (getiterfunc)hamt_tp_iter,
+ .tp_dealloc = (destructor)hamt_tp_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_richcompare = hamt_tp_richcompare,
+ .tp_traverse = (traverseproc)hamt_tp_traverse,
+ .tp_clear = (inquiry)hamt_tp_clear,
+ .tp_new = hamt_tp_new,
+ .tp_weaklistoffset = offsetof(PyHamtObject, h_weakreflist),
+ .tp_hash = PyObject_HashNotImplemented,
+};
+
+
+/////////////////////////////////// Tree Node Types
+
+
+PyTypeObject _PyHamt_ArrayNode_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "hamt_array_node",
+ sizeof(PyHamtNode_Array),
+ 0,
+ .tp_dealloc = (destructor)hamt_node_array_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_traverse = (traverseproc)hamt_node_array_traverse,
+ .tp_free = PyObject_GC_Del,
+ .tp_hash = PyObject_HashNotImplemented,
+};
+
+PyTypeObject _PyHamt_BitmapNode_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "hamt_bitmap_node",
+ sizeof(PyHamtNode_Bitmap) - sizeof(PyObject *),
+ sizeof(PyObject *),
+ .tp_dealloc = (destructor)hamt_node_bitmap_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_traverse = (traverseproc)hamt_node_bitmap_traverse,
+ .tp_free = PyObject_GC_Del,
+ .tp_hash = PyObject_HashNotImplemented,
+};
+
+PyTypeObject _PyHamt_CollisionNode_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "hamt_collision_node",
+ sizeof(PyHamtNode_Collision) - sizeof(PyObject *),
+ sizeof(PyObject *),
+ .tp_dealloc = (destructor)hamt_node_collision_dealloc,
+ .tp_getattro = PyObject_GenericGetAttr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+ .tp_traverse = (traverseproc)hamt_node_collision_traverse,
+ .tp_free = PyObject_GC_Del,
+ .tp_hash = PyObject_HashNotImplemented,
+};
+
+
+int
+_PyHamt_Init(void)
+{
+ if ((PyType_Ready(&_PyHamt_Type) < 0) ||
+ (PyType_Ready(&_PyHamt_ArrayNode_Type) < 0) ||
+ (PyType_Ready(&_PyHamt_BitmapNode_Type) < 0) ||
+ (PyType_Ready(&_PyHamt_CollisionNode_Type) < 0) ||
+ (PyType_Ready(&_PyHamtKeys_Type) < 0) ||
+ (PyType_Ready(&_PyHamtValues_Type) < 0) ||
+ (PyType_Ready(&_PyHamtItems_Type) < 0))
+ {
+ return 0;
+ }
+
+ return 1;
+}
+
+void
+_PyHamt_Fini(void)
+{
+ Py_CLEAR(_empty_hamt);
+ Py_CLEAR(_empty_bitmap_node);
+}
diff --git a/Python/pylifecycle.c b/Python/pylifecycle.c
index 2f61db0..d46784a 100644
--- a/Python/pylifecycle.c
+++ b/Python/pylifecycle.c
@@ -4,6 +4,8 @@
#include "Python-ast.h"
#undef Yield /* undefine macro conflicting with winbase.h */
+#include "internal/context.h"
+#include "internal/hamt.h"
#include "internal/pystate.h"
#include "grammar.h"
#include "node.h"
@@ -758,6 +760,9 @@ _Py_InitializeCore(const _PyCoreConfig *core_config)
return _Py_INIT_ERR("can't initialize warnings");
}
+ if (!_PyContext_Init())
+ return _Py_INIT_ERR("can't init context");
+
/* This call sets up builtin and frozen import support */
if (!interp->core_config._disable_importlib) {
err = initimport(interp, sysmod);
@@ -1176,6 +1181,7 @@ Py_FinalizeEx(void)
_Py_HashRandomization_Fini();
_PyArg_Fini();
PyAsyncGen_Fini();
+ _PyContext_Fini();
/* Cleanup Unicode implementation */
_PyUnicode_Fini();
diff --git a/Python/pystate.c b/Python/pystate.c
index 9c25a26..909d831 100644
--- a/Python/pystate.c
+++ b/Python/pystate.c
@@ -173,6 +173,8 @@ PyInterpreterState_New(void)
}
HEAD_UNLOCK();
+ interp->tstate_next_unique_id = 0;
+
return interp;
}
@@ -313,6 +315,11 @@ new_threadstate(PyInterpreterState *interp, int init)
tstate->async_gen_firstiter = NULL;
tstate->async_gen_finalizer = NULL;
+ tstate->context = NULL;
+ tstate->context_ver = 1;
+
+ tstate->id = ++interp->tstate_next_unique_id;
+
if (init)
_PyThreadState_Init(tstate);
@@ -499,6 +506,8 @@ PyThreadState_Clear(PyThreadState *tstate)
Py_CLEAR(tstate->coroutine_wrapper);
Py_CLEAR(tstate->async_gen_firstiter);
Py_CLEAR(tstate->async_gen_finalizer);
+
+ Py_CLEAR(tstate->context);
}
diff --git a/Tools/msi/lib/lib_files.wxs b/Tools/msi/lib/lib_files.wxs
index 5a72612..46ddcb4 100644
--- a/Tools/msi/lib/lib_files.wxs
+++ b/Tools/msi/lib/lib_files.wxs
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<Wix xmlns="http://schemas.microsoft.com/wix/2006/wi">
- <?define exts=pyexpat;select;unicodedata;winsound;_bz2;_elementtree;_socket;_ssl;_msi;_ctypes;_hashlib;_multiprocessing;_lzma;_decimal;_overlapped;_sqlite3;_asyncio;_queue;_distutils_findvs ?>
+ <?define exts=pyexpat;select;unicodedata;winsound;_bz2;_elementtree;_socket;_ssl;_msi;_ctypes;_hashlib;_multiprocessing;_lzma;_decimal;_overlapped;_sqlite3;_asyncio;_queue;_distutils_findvs;_contextvars ?>
<Fragment>
<ComponentGroup Id="lib_extensions">
<?foreach ext in $(var.exts)?>
diff --git a/setup.py b/setup.py
index 1da40a4..258094e 100644
--- a/setup.py
+++ b/setup.py
@@ -644,6 +644,9 @@ class PyBuildExt(build_ext):
# array objects
exts.append( Extension('array', ['arraymodule.c']) )
+ # Context Variables
+ exts.append( Extension('_contextvars', ['_contextvarsmodule.c']) )
+
shared_math = 'Modules/_math.o'
# complex math library functions
exts.append( Extension('cmath', ['cmathmodule.c'],