summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-01-27 18:46:46 (GMT)
committerGitHub <noreply@github.com>2018-01-27 18:46:46 (GMT)
commitf13f12d8daa587b5fcc66fe3ed1090a5dadab289 (patch)
tree50f8217d1fcbcdf447e2a21b9e903cb6824ac8af
parentbc4123b0b380edda774b8bff2fa1bcc96453b440 (diff)
downloadcpython-f13f12d8daa587b5fcc66fe3ed1090a5dadab289.zip
cpython-f13f12d8daa587b5fcc66fe3ed1090a5dadab289.tar.gz
cpython-f13f12d8daa587b5fcc66fe3ed1090a5dadab289.tar.bz2
bpo-32630: Use contextvars in decimal (GH-5278)
-rw-r--r--Lib/_pydecimal.py20
-rw-r--r--Lib/test/test_asyncio/test_context.py29
-rw-r--r--Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst1
-rw-r--r--Modules/_decimal/_decimal.c120
4 files changed, 70 insertions, 100 deletions
diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py
index a1662bb..3596900 100644
--- a/Lib/_pydecimal.py
+++ b/Lib/_pydecimal.py
@@ -433,13 +433,11 @@ _rounding_modes = (ROUND_DOWN, ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING,
# The getcontext() and setcontext() function manage access to a thread-local
# current context.
-import threading
+import contextvars
-local = threading.local()
-if hasattr(local, '__decimal_context__'):
- del local.__decimal_context__
+_current_context_var = contextvars.ContextVar('decimal_context')
-def getcontext(_local=local):
+def getcontext():
"""Returns this thread's context.
If this thread does not yet have a context, returns
@@ -447,20 +445,20 @@ def getcontext(_local=local):
New contexts are copies of DefaultContext.
"""
try:
- return _local.__decimal_context__
- except AttributeError:
+ return _current_context_var.get()
+ except LookupError:
context = Context()
- _local.__decimal_context__ = context
+ _current_context_var.set(context)
return context
-def setcontext(context, _local=local):
+def setcontext(context):
"""Set this thread's context to context."""
if context in (DefaultContext, BasicContext, ExtendedContext):
context = context.copy()
context.clear_flags()
- _local.__decimal_context__ = context
+ _current_context_var.set(context)
-del threading, local # Don't contaminate the namespace
+del contextvars # Don't contaminate the namespace
def localcontext(ctx=None):
"""Return a context manager for a copy of the supplied context
diff --git a/Lib/test/test_asyncio/test_context.py b/Lib/test/test_asyncio/test_context.py
new file mode 100644
index 0000000..6abddd9f2
--- /dev/null
+++ b/Lib/test/test_asyncio/test_context.py
@@ -0,0 +1,29 @@
+import asyncio
+import decimal
+import unittest
+
+
+class DecimalContextTest(unittest.TestCase):
+
+ def test_asyncio_task_decimal_context(self):
+ async def fractions(t, precision, x, y):
+ with decimal.localcontext() as ctx:
+ ctx.prec = precision
+ a = decimal.Decimal(x) / decimal.Decimal(y)
+ await asyncio.sleep(t)
+ b = decimal.Decimal(x) / decimal.Decimal(y ** 2)
+ return a, b
+
+ async def main():
+ r1, r2 = await asyncio.gather(
+ fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))
+
+ return r1, r2
+
+ r1, r2 = asyncio.run(main())
+
+ self.assertEqual(str(r1[0]), '0.333')
+ self.assertEqual(str(r1[1]), '0.111')
+
+ self.assertEqual(str(r2[0]), '0.333333')
+ self.assertEqual(str(r2[1]), '0.111111')
diff --git a/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst b/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst
new file mode 100644
index 0000000..1bbcbb1
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst
@@ -0,0 +1 @@
+Refactor decimal module to use contextvars to store decimal context.
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c
index 18fa2e4..fddb39e 100644
--- a/Modules/_decimal/_decimal.c
+++ b/Modules/_decimal/_decimal.c
@@ -122,10 +122,7 @@ incr_false(void)
}
-/* Key for thread state dictionary */
-static PyObject *tls_context_key = NULL;
-/* Invariant: NULL or the most recently accessed thread local context */
-static PyDecContextObject *cached_context = NULL;
+static PyContextVar *current_context_var;
/* Template for creating new thread contexts, calling Context() without
* arguments and initializing the module_context on first access. */
@@ -1220,10 +1217,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED)
static void
context_dealloc(PyDecContextObject *self)
{
- if (self == cached_context) {
- cached_context = NULL;
- }
-
Py_XDECREF(self->traps);
Py_XDECREF(self->flags);
Py_TYPE(self)->tp_free(self);
@@ -1498,69 +1491,38 @@ static PyGetSetDef context_getsets [] =
* operation.
*/
-/* Get the context from the thread state dictionary. */
static PyObject *
-current_context_from_dict(void)
+init_current_context(void)
{
- PyObject *dict;
- PyObject *tl_context;
- PyThreadState *tstate;
-
- dict = PyThreadState_GetDict();
- if (dict == NULL) {
- PyErr_SetString(PyExc_RuntimeError,
- "cannot get thread state");
+ PyObject *tl_context = context_copy(default_context_template, NULL);
+ if (tl_context == NULL) {
return NULL;
}
+ CTX(tl_context)->status = 0;
- tl_context = PyDict_GetItemWithError(dict, tls_context_key);
- if (tl_context != NULL) {
- /* We already have a thread local context. */
- CONTEXT_CHECK(tl_context);
- }
- else {
- if (PyErr_Occurred()) {
- return NULL;
- }
-
- /* Set up a new thread local context. */
- tl_context = context_copy(default_context_template, NULL);
- if (tl_context == NULL) {
- return NULL;
- }
- CTX(tl_context)->status = 0;
-
- if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) {
- Py_DECREF(tl_context);
- return NULL;
- }
+ PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context);
+ if (tok == NULL) {
Py_DECREF(tl_context);
+ return NULL;
}
+ Py_DECREF(tok);
- /* Cache the context of the current thread, assuming that it
- * will be accessed several times before a thread switch. */
- tstate = PyThreadState_GET();
- if (tstate) {
- cached_context = (PyDecContextObject *)tl_context;
- cached_context->tstate = tstate;
- }
-
- /* Borrowed reference with refcount==1 */
return tl_context;
}
-/* Return borrowed reference to thread local context. */
-static PyObject *
+static inline PyObject *
current_context(void)
{
- PyThreadState *tstate;
+ PyObject *tl_context;
+ if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) {
+ return NULL;
+ }
- tstate = PyThreadState_GET();
- if (cached_context && cached_context->tstate == tstate) {
- return (PyObject *)cached_context;
+ if (tl_context != NULL) {
+ return tl_context;
}
- return current_context_from_dict();
+ return init_current_context();
}
/* ctxobj := borrowed reference to the current context */
@@ -1568,47 +1530,22 @@ current_context(void)
ctxobj = current_context(); \
if (ctxobj == NULL) { \
return NULL; \
- }
-
-/* ctx := pointer to the mpd_context_t struct of the current context */
-#define CURRENT_CONTEXT_ADDR(ctx) { \
- PyObject *_c_t_x_o_b_j = current_context(); \
- if (_c_t_x_o_b_j == NULL) { \
- return NULL; \
- } \
- ctx = CTX(_c_t_x_o_b_j); \
-}
+ } \
+ Py_DECREF(ctxobj);
/* Return a new reference to the current context */
static PyObject *
PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
{
- PyObject *context;
-
- context = current_context();
- if (context == NULL) {
- return NULL;
- }
-
- Py_INCREF(context);
- return context;
+ return current_context();
}
/* Set the thread local context to a new context, decrement old reference */
static PyObject *
PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
{
- PyObject *dict;
-
CONTEXT_CHECK(v);
- dict = PyThreadState_GetDict();
- if (dict == NULL) {
- PyErr_SetString(PyExc_RuntimeError,
- "cannot get thread state");
- return NULL;
- }
-
/* If the new context is one of the templates, make a copy.
* This is the current behavior of decimal.py. */
if (v == default_context_template ||
@@ -1624,13 +1561,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
Py_INCREF(v);
}
- cached_context = NULL;
- if (PyDict_SetItem(dict, tls_context_key, v) < 0) {
- Py_DECREF(v);
+ PyContextToken *tok = PyContextVar_Set(current_context_var, v);
+ Py_DECREF(v);
+ if (tok == NULL) {
return NULL;
}
+ Py_DECREF(tok);
- Py_DECREF(v);
Py_RETURN_NONE;
}
@@ -4458,6 +4395,7 @@ _dec_hash(PyDecObject *v)
if (context == NULL) {
return -1;
}
+ Py_DECREF(context);
if (mpd_isspecial(MPD(v))) {
if (mpd_issnan(MPD(v))) {
@@ -5599,6 +5537,11 @@ PyInit__decimal(void)
mpd_free = PyMem_Free;
mpd_setminalloc(_Py_DEC_MINALLOC);
+ /* Init context variable */
+ current_context_var = PyContextVar_New("decimal_context", NULL);
+ if (current_context_var == NULL) {
+ goto error;
+ }
/* Init external C-API functions */
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
@@ -5768,7 +5711,6 @@ PyInit__decimal(void)
CHECK_INT(PyModule_AddObject(m, "DefaultContext",
default_context_template));
- ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__"));
Py_INCREF(Py_True);
CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));
@@ -5827,9 +5769,9 @@ error:
Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
- Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */
Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
+ Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */
Py_CLEAR(m); /* GCOV_NOT_REACHED */
return NULL; /* GCOV_NOT_REACHED */