summaryrefslogtreecommitdiffstats
path: root/Modules/_decimal
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-01-27 18:46:46 (GMT)
committerGitHub <noreply@github.com>2018-01-27 18:46:46 (GMT)
commitf13f12d8daa587b5fcc66fe3ed1090a5dadab289 (patch)
tree50f8217d1fcbcdf447e2a21b9e903cb6824ac8af /Modules/_decimal
parentbc4123b0b380edda774b8bff2fa1bcc96453b440 (diff)
downloadcpython-f13f12d8daa587b5fcc66fe3ed1090a5dadab289.zip
cpython-f13f12d8daa587b5fcc66fe3ed1090a5dadab289.tar.gz
cpython-f13f12d8daa587b5fcc66fe3ed1090a5dadab289.tar.bz2
bpo-32630: Use contextvars in decimal (GH-5278)
Diffstat (limited to 'Modules/_decimal')
-rw-r--r--Modules/_decimal/_decimal.c120
1 files changed, 31 insertions, 89 deletions
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c
index 18fa2e4..fddb39e 100644
--- a/Modules/_decimal/_decimal.c
+++ b/Modules/_decimal/_decimal.c
@@ -122,10 +122,7 @@ incr_false(void)
}
-/* Key for thread state dictionary */
-static PyObject *tls_context_key = NULL;
-/* Invariant: NULL or the most recently accessed thread local context */
-static PyDecContextObject *cached_context = NULL;
+static PyContextVar *current_context_var;
/* Template for creating new thread contexts, calling Context() without
* arguments and initializing the module_context on first access. */
@@ -1220,10 +1217,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED)
static void
context_dealloc(PyDecContextObject *self)
{
- if (self == cached_context) {
- cached_context = NULL;
- }
-
Py_XDECREF(self->traps);
Py_XDECREF(self->flags);
Py_TYPE(self)->tp_free(self);
@@ -1498,69 +1491,38 @@ static PyGetSetDef context_getsets [] =
* operation.
*/
-/* Get the context from the thread state dictionary. */
static PyObject *
-current_context_from_dict(void)
+init_current_context(void)
{
- PyObject *dict;
- PyObject *tl_context;
- PyThreadState *tstate;
-
- dict = PyThreadState_GetDict();
- if (dict == NULL) {
- PyErr_SetString(PyExc_RuntimeError,
- "cannot get thread state");
+ PyObject *tl_context = context_copy(default_context_template, NULL);
+ if (tl_context == NULL) {
return NULL;
}
+ CTX(tl_context)->status = 0;
- tl_context = PyDict_GetItemWithError(dict, tls_context_key);
- if (tl_context != NULL) {
- /* We already have a thread local context. */
- CONTEXT_CHECK(tl_context);
- }
- else {
- if (PyErr_Occurred()) {
- return NULL;
- }
-
- /* Set up a new thread local context. */
- tl_context = context_copy(default_context_template, NULL);
- if (tl_context == NULL) {
- return NULL;
- }
- CTX(tl_context)->status = 0;
-
- if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) {
- Py_DECREF(tl_context);
- return NULL;
- }
+ PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context);
+ if (tok == NULL) {
Py_DECREF(tl_context);
+ return NULL;
}
+ Py_DECREF(tok);
- /* Cache the context of the current thread, assuming that it
- * will be accessed several times before a thread switch. */
- tstate = PyThreadState_GET();
- if (tstate) {
- cached_context = (PyDecContextObject *)tl_context;
- cached_context->tstate = tstate;
- }
-
- /* Borrowed reference with refcount==1 */
return tl_context;
}
-/* Return borrowed reference to thread local context. */
-static PyObject *
+static inline PyObject *
current_context(void)
{
- PyThreadState *tstate;
+ PyObject *tl_context;
+ if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) {
+ return NULL;
+ }
- tstate = PyThreadState_GET();
- if (cached_context && cached_context->tstate == tstate) {
- return (PyObject *)cached_context;
+ if (tl_context != NULL) {
+ return tl_context;
}
- return current_context_from_dict();
+ return init_current_context();
}
/* ctxobj := borrowed reference to the current context */
@@ -1568,47 +1530,22 @@ current_context(void)
ctxobj = current_context(); \
if (ctxobj == NULL) { \
return NULL; \
- }
-
-/* ctx := pointer to the mpd_context_t struct of the current context */
-#define CURRENT_CONTEXT_ADDR(ctx) { \
- PyObject *_c_t_x_o_b_j = current_context(); \
- if (_c_t_x_o_b_j == NULL) { \
- return NULL; \
- } \
- ctx = CTX(_c_t_x_o_b_j); \
-}
+ } \
+ Py_DECREF(ctxobj);
/* Return a new reference to the current context */
static PyObject *
PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
{
- PyObject *context;
-
- context = current_context();
- if (context == NULL) {
- return NULL;
- }
-
- Py_INCREF(context);
- return context;
+ return current_context();
}
/* Set the thread local context to a new context, decrement old reference */
static PyObject *
PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
{
- PyObject *dict;
-
CONTEXT_CHECK(v);
- dict = PyThreadState_GetDict();
- if (dict == NULL) {
- PyErr_SetString(PyExc_RuntimeError,
- "cannot get thread state");
- return NULL;
- }
-
/* If the new context is one of the templates, make a copy.
* This is the current behavior of decimal.py. */
if (v == default_context_template ||
@@ -1624,13 +1561,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
Py_INCREF(v);
}
- cached_context = NULL;
- if (PyDict_SetItem(dict, tls_context_key, v) < 0) {
- Py_DECREF(v);
+ PyContextToken *tok = PyContextVar_Set(current_context_var, v);
+ Py_DECREF(v);
+ if (tok == NULL) {
return NULL;
}
+ Py_DECREF(tok);
- Py_DECREF(v);
Py_RETURN_NONE;
}
@@ -4458,6 +4395,7 @@ _dec_hash(PyDecObject *v)
if (context == NULL) {
return -1;
}
+ Py_DECREF(context);
if (mpd_isspecial(MPD(v))) {
if (mpd_issnan(MPD(v))) {
@@ -5599,6 +5537,11 @@ PyInit__decimal(void)
mpd_free = PyMem_Free;
mpd_setminalloc(_Py_DEC_MINALLOC);
+ /* Init context variable */
+ current_context_var = PyContextVar_New("decimal_context", NULL);
+ if (current_context_var == NULL) {
+ goto error;
+ }
/* Init external C-API functions */
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
@@ -5768,7 +5711,6 @@ PyInit__decimal(void)
CHECK_INT(PyModule_AddObject(m, "DefaultContext",
default_context_template));
- ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__"));
Py_INCREF(Py_True);
CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));
@@ -5827,9 +5769,9 @@ error:
Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
- Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */
Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
+ Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */
Py_CLEAR(m); /* GCOV_NOT_REACHED */
return NULL; /* GCOV_NOT_REACHED */