1 #include "Python.h"
2 #include "pycore_call.h"          // _PyObject_VectorcallTstate()
3 #include "pycore_context.h"
4 #include "pycore_gc.h"            // _PyObject_GC_MAY_BE_TRACKED()
5 #include "pycore_hamt.h"
6 #include "pycore_initconfig.h"    // _PyStatus_OK()
7 #include "pycore_object.h"
8 #include "pycore_pyerrors.h"
9 #include "pycore_pystate.h"       // _PyThreadState_GET()
10 #include "structmember.h"         // PyMemberDef
11 
12 
13 #include "clinic/context.c.h"
14 /*[clinic input]
15 module _contextvars
16 [clinic start generated code]*/
17 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
18 
19 
/* Type guard: when `o` is not exactly a Context, set a TypeError and make
   the enclosing function return `err_ret`.  The expansion is a bare `if`
   block, and the call sites in this file use it without a trailing
   semicolon. */
#define ENSURE_Context(o, err_ret)                                  \
    if (!PyContext_CheckExact(o)) {                                 \
        PyErr_SetString(PyExc_TypeError,                            \
                        "an instance of Context was expected");     \
        return err_ret;                                             \
    }
26 
/* Type guard: when `o` is not exactly a ContextVar, set a TypeError and make
   the enclosing function return `err_ret`.  The expansion is a bare `if`
   block, and the call sites in this file use it without a trailing
   semicolon. */
#define ENSURE_ContextVar(o, err_ret)                               \
    if (!PyContextVar_CheckExact(o)) {                              \
        PyErr_SetString(PyExc_TypeError,                            \
                       "an instance of ContextVar was expected");   \
        return err_ret;                                             \
    }
33 
/* Type guard: when `o` is not exactly a Token, set a TypeError and make
   the enclosing function return `err_ret`.  The expansion is a bare `if`
   block, and the call sites in this file use it without a trailing
   semicolon. */
#define ENSURE_ContextToken(o, err_ret)                             \
    if (!PyContextToken_CheckExact(o)) {                            \
        PyErr_SetString(PyExc_TypeError,                            \
                        "an instance of Token was expected");       \
        return err_ret;                                             \
    }
40 
41 
42 /////////////////////////// Context API
43 
44 
/* Forward declarations of static helpers that the public API functions
   below call before the helpers' definitions appear. */

/* Allocate a Context whose variable mapping is a new, empty HAMT. */
static PyContext *
context_new_empty(void);

/* Allocate a Context that shares (takes a new reference to) `vars`. */
static PyContext *
context_new_from_vars(PyHamtObject *vars);

/* Return the current thread's Context, creating it on first use. */
static inline PyContext *
context_get(void);

static PyContextToken *
token_new(PyContext *ctx, PyContextVar *var, PyObject *val);

static PyContextVar *
contextvar_new(PyObject *name, PyObject *def);

/* Bind `var` to `val` in the current Context; 0 on success, -1 on error. */
static int
contextvar_set(PyContextVar *var, PyObject *val);

/* Remove `var` from the current Context; 0 on success, -1 on error. */
static int
contextvar_del(PyContextVar *var);
65 
66 
#if PyContext_MAXFREELIST > 0
/* Return the address of the `context` state embedded in the current
   interpreter.  Only compiled when the Context free list is enabled. */
static struct _Py_context_state *
get_context_state(void)
{
    return &_PyInterpreterState_GET()->context;
}
#endif
75 
76 
77 PyObject *
_PyContext_NewHamtForTests(void)78 _PyContext_NewHamtForTests(void)
79 {
80     return (PyObject *)_PyHamt_New();
81 }
82 
83 
84 PyObject *
PyContext_New(void)85 PyContext_New(void)
86 {
87     return (PyObject *)context_new_empty();
88 }
89 
90 
91 PyObject *
PyContext_Copy(PyObject * octx)92 PyContext_Copy(PyObject * octx)
93 {
94     ENSURE_Context(octx, NULL)
95     PyContext *ctx = (PyContext *)octx;
96     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
97 }
98 
99 
100 PyObject *
PyContext_CopyCurrent(void)101 PyContext_CopyCurrent(void)
102 {
103     PyContext *ctx = context_get();
104     if (ctx == NULL) {
105         return NULL;
106     }
107 
108     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
109 }
110 
111 
112 static int
_PyContext_Enter(PyThreadState * ts,PyObject * octx)113 _PyContext_Enter(PyThreadState *ts, PyObject *octx)
114 {
115     ENSURE_Context(octx, -1)
116     PyContext *ctx = (PyContext *)octx;
117 
118     if (ctx->ctx_entered) {
119         _PyErr_Format(ts, PyExc_RuntimeError,
120                       "cannot enter context: %R is already entered", ctx);
121         return -1;
122     }
123 
124     ctx->ctx_prev = (PyContext *)ts->context;  /* borrow */
125     ctx->ctx_entered = 1;
126 
127     Py_INCREF(ctx);
128     ts->context = (PyObject *)ctx;
129     ts->context_ver++;
130 
131     return 0;
132 }
133 
134 
135 int
PyContext_Enter(PyObject * octx)136 PyContext_Enter(PyObject *octx)
137 {
138     PyThreadState *ts = _PyThreadState_GET();
139     assert(ts != NULL);
140     return _PyContext_Enter(ts, octx);
141 }
142 
143 
144 static int
_PyContext_Exit(PyThreadState * ts,PyObject * octx)145 _PyContext_Exit(PyThreadState *ts, PyObject *octx)
146 {
147     ENSURE_Context(octx, -1)
148     PyContext *ctx = (PyContext *)octx;
149 
150     if (!ctx->ctx_entered) {
151         PyErr_Format(PyExc_RuntimeError,
152                      "cannot exit context: %R has not been entered", ctx);
153         return -1;
154     }
155 
156     if (ts->context != (PyObject *)ctx) {
157         /* Can only happen if someone misuses the C API */
158         PyErr_SetString(PyExc_RuntimeError,
159                         "cannot exit context: thread state references "
160                         "a different context object");
161         return -1;
162     }
163 
164     Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
165     ts->context_ver++;
166 
167     ctx->ctx_prev = NULL;
168     ctx->ctx_entered = 0;
169 
170     return 0;
171 }
172 
173 int
PyContext_Exit(PyObject * octx)174 PyContext_Exit(PyObject *octx)
175 {
176     PyThreadState *ts = _PyThreadState_GET();
177     assert(ts != NULL);
178     return _PyContext_Exit(ts, octx);
179 }
180 
181 
182 PyObject *
PyContextVar_New(const char * name,PyObject * def)183 PyContextVar_New(const char *name, PyObject *def)
184 {
185     PyObject *pyname = PyUnicode_FromString(name);
186     if (pyname == NULL) {
187         return NULL;
188     }
189     PyContextVar *var = contextvar_new(pyname, def);
190     Py_DECREF(pyname);
191     return (PyObject *)var;
192 }
193 
194 
195 int
PyContextVar_Get(PyObject * ovar,PyObject * def,PyObject ** val)196 PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
197 {
198     ENSURE_ContextVar(ovar, -1)
199     PyContextVar *var = (PyContextVar *)ovar;
200 
201     PyThreadState *ts = _PyThreadState_GET();
202     assert(ts != NULL);
203     if (ts->context == NULL) {
204         goto not_found;
205     }
206 
207     if (var->var_cached != NULL &&
208             var->var_cached_tsid == ts->id &&
209             var->var_cached_tsver == ts->context_ver)
210     {
211         *val = var->var_cached;
212         goto found;
213     }
214 
215     assert(PyContext_CheckExact(ts->context));
216     PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
217 
218     PyObject *found = NULL;
219     int res = _PyHamt_Find(vars, (PyObject*)var, &found);
220     if (res < 0) {
221         goto error;
222     }
223     if (res == 1) {
224         assert(found != NULL);
225         var->var_cached = found;  /* borrow */
226         var->var_cached_tsid = ts->id;
227         var->var_cached_tsver = ts->context_ver;
228 
229         *val = found;
230         goto found;
231     }
232 
233 not_found:
234     if (def == NULL) {
235         if (var->var_default != NULL) {
236             *val = var->var_default;
237             goto found;
238         }
239 
240         *val = NULL;
241         goto found;
242     }
243     else {
244         *val = def;
245         goto found;
246    }
247 
248 found:
249     Py_XINCREF(*val);
250     return 0;
251 
252 error:
253     *val = NULL;
254     return -1;
255 }
256 
257 
258 PyObject *
PyContextVar_Set(PyObject * ovar,PyObject * val)259 PyContextVar_Set(PyObject *ovar, PyObject *val)
260 {
261     ENSURE_ContextVar(ovar, NULL)
262     PyContextVar *var = (PyContextVar *)ovar;
263 
264     if (!PyContextVar_CheckExact(var)) {
265         PyErr_SetString(
266             PyExc_TypeError, "an instance of ContextVar was expected");
267         return NULL;
268     }
269 
270     PyContext *ctx = context_get();
271     if (ctx == NULL) {
272         return NULL;
273     }
274 
275     PyObject *old_val = NULL;
276     int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val);
277     if (found < 0) {
278         return NULL;
279     }
280 
281     Py_XINCREF(old_val);
282     PyContextToken *tok = token_new(ctx, var, old_val);
283     Py_XDECREF(old_val);
284 
285     if (contextvar_set(var, val)) {
286         Py_DECREF(tok);
287         return NULL;
288     }
289 
290     return (PyObject *)tok;
291 }
292 
293 
294 int
PyContextVar_Reset(PyObject * ovar,PyObject * otok)295 PyContextVar_Reset(PyObject *ovar, PyObject *otok)
296 {
297     ENSURE_ContextVar(ovar, -1)
298     ENSURE_ContextToken(otok, -1)
299     PyContextVar *var = (PyContextVar *)ovar;
300     PyContextToken *tok = (PyContextToken *)otok;
301 
302     if (tok->tok_used) {
303         PyErr_Format(PyExc_RuntimeError,
304                      "%R has already been used once", tok);
305         return -1;
306     }
307 
308     if (var != tok->tok_var) {
309         PyErr_Format(PyExc_ValueError,
310                      "%R was created by a different ContextVar", tok);
311         return -1;
312     }
313 
314     PyContext *ctx = context_get();
315     if (ctx != tok->tok_ctx) {
316         PyErr_Format(PyExc_ValueError,
317                      "%R was created in a different Context", tok);
318         return -1;
319     }
320 
321     tok->tok_used = 1;
322 
323     if (tok->tok_oldval == NULL) {
324         return contextvar_del(var);
325     }
326     else {
327         return contextvar_set(var, tok->tok_oldval);
328     }
329 }
330 
331 
332 /////////////////////////// PyContext
333 
334 /*[clinic input]
335 class _contextvars.Context "PyContext *" "&PyContext_Type"
336 [clinic start generated code]*/
337 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/
338 
339 
340 static inline PyContext *
_context_alloc(void)341 _context_alloc(void)
342 {
343     PyContext *ctx;
344 #if PyContext_MAXFREELIST > 0
345     struct _Py_context_state *state = get_context_state();
346 #ifdef Py_DEBUG
347     // _context_alloc() must not be called after _PyContext_Fini()
348     assert(state->numfree != -1);
349 #endif
350     if (state->numfree) {
351         state->numfree--;
352         ctx = state->freelist;
353         state->freelist = (PyContext *)ctx->ctx_weakreflist;
354         OBJECT_STAT_INC(from_freelist);
355         ctx->ctx_weakreflist = NULL;
356         _Py_NewReference((PyObject *)ctx);
357     }
358     else
359 #endif
360     {
361         ctx = PyObject_GC_New(PyContext, &PyContext_Type);
362         if (ctx == NULL) {
363             return NULL;
364         }
365     }
366 
367     ctx->ctx_vars = NULL;
368     ctx->ctx_prev = NULL;
369     ctx->ctx_entered = 0;
370     ctx->ctx_weakreflist = NULL;
371 
372     return ctx;
373 }
374 
375 
376 static PyContext *
context_new_empty(void)377 context_new_empty(void)
378 {
379     PyContext *ctx = _context_alloc();
380     if (ctx == NULL) {
381         return NULL;
382     }
383 
384     ctx->ctx_vars = _PyHamt_New();
385     if (ctx->ctx_vars == NULL) {
386         Py_DECREF(ctx);
387         return NULL;
388     }
389 
390     _PyObject_GC_TRACK(ctx);
391     return ctx;
392 }
393 
394 
395 static PyContext *
context_new_from_vars(PyHamtObject * vars)396 context_new_from_vars(PyHamtObject *vars)
397 {
398     PyContext *ctx = _context_alloc();
399     if (ctx == NULL) {
400         return NULL;
401     }
402 
403     Py_INCREF(vars);
404     ctx->ctx_vars = vars;
405 
406     _PyObject_GC_TRACK(ctx);
407     return ctx;
408 }
409 
410 
411 static inline PyContext *
context_get(void)412 context_get(void)
413 {
414     PyThreadState *ts = _PyThreadState_GET();
415     assert(ts != NULL);
416     PyContext *current_ctx = (PyContext *)ts->context;
417     if (current_ctx == NULL) {
418         current_ctx = context_new_empty();
419         if (current_ctx == NULL) {
420             return NULL;
421         }
422         ts->context = (PyObject *)current_ctx;
423     }
424     return current_ctx;
425 }
426 
427 static int
context_check_key_type(PyObject * key)428 context_check_key_type(PyObject *key)
429 {
430     if (!PyContextVar_CheckExact(key)) {
431         // abort();
432         PyErr_Format(PyExc_TypeError,
433                      "a ContextVar key was expected, got %R", key);
434         return -1;
435     }
436     return 0;
437 }
438 
439 static PyObject *
context_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)440 context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
441 {
442     if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) {
443         PyErr_SetString(
444             PyExc_TypeError, "Context() does not accept any arguments");
445         return NULL;
446     }
447     return PyContext_New();
448 }
449 
/* tp_clear for Context: release the references held in ctx_prev and
   ctx_vars and reset both fields to NULL.  Always returns 0. */
static int
context_tp_clear(PyContext *self)
{
    Py_CLEAR(self->ctx_prev);
    Py_CLEAR(self->ctx_vars);
    return 0;
}
457 
/* tp_traverse for Context: report ctx_prev and ctx_vars to the garbage
   collector.  Py_VISIT returns early with the visitor's result if it is
   non-zero; otherwise 0 is returned. */
static int
context_tp_traverse(PyContext *self, visitproc visit, void *arg)
{
    Py_VISIT(self->ctx_prev);
    Py_VISIT(self->ctx_vars);
    return 0;
}
465 
466 static void
context_tp_dealloc(PyContext * self)467 context_tp_dealloc(PyContext *self)
468 {
469     _PyObject_GC_UNTRACK(self);
470 
471     if (self->ctx_weakreflist != NULL) {
472         PyObject_ClearWeakRefs((PyObject*)self);
473     }
474     (void)context_tp_clear(self);
475 
476 #if PyContext_MAXFREELIST > 0
477     struct _Py_context_state *state = get_context_state();
478 #ifdef Py_DEBUG
479     // _context_alloc() must not be called after _PyContext_Fini()
480     assert(state->numfree != -1);
481 #endif
482     if (state->numfree < PyContext_MAXFREELIST) {
483         state->numfree++;
484         self->ctx_weakreflist = (PyObject *)state->freelist;
485         state->freelist = self;
486         OBJECT_STAT_INC(to_freelist);
487     }
488     else
489 #endif
490     {
491         Py_TYPE(self)->tp_free(self);
492     }
493 }
494 
495 static PyObject *
context_tp_iter(PyContext * self)496 context_tp_iter(PyContext *self)
497 {
498     return _PyHamt_NewIterKeys(self->ctx_vars);
499 }
500 
501 static PyObject *
context_tp_richcompare(PyObject * v,PyObject * w,int op)502 context_tp_richcompare(PyObject *v, PyObject *w, int op)
503 {
504     if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) ||
505             (op != Py_EQ && op != Py_NE))
506     {
507         Py_RETURN_NOTIMPLEMENTED;
508     }
509 
510     int res = _PyHamt_Eq(
511         ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars);
512     if (res < 0) {
513         return NULL;
514     }
515 
516     if (op == Py_NE) {
517         res = !res;
518     }
519 
520     if (res) {
521         Py_RETURN_TRUE;
522     }
523     else {
524         Py_RETURN_FALSE;
525     }
526 }
527 
528 static Py_ssize_t
context_tp_len(PyContext * self)529 context_tp_len(PyContext *self)
530 {
531     return _PyHamt_Len(self->ctx_vars);
532 }
533 
534 static PyObject *
context_tp_subscript(PyContext * self,PyObject * key)535 context_tp_subscript(PyContext *self, PyObject *key)
536 {
537     if (context_check_key_type(key)) {
538         return NULL;
539     }
540     PyObject *val = NULL;
541     int found = _PyHamt_Find(self->ctx_vars, key, &val);
542     if (found < 0) {
543         return NULL;
544     }
545     if (found == 0) {
546         PyErr_SetObject(PyExc_KeyError, key);
547         return NULL;
548     }
549     Py_INCREF(val);
550     return val;
551 }
552 
553 static int
context_tp_contains(PyContext * self,PyObject * key)554 context_tp_contains(PyContext *self, PyObject *key)
555 {
556     if (context_check_key_type(key)) {
557         return -1;
558     }
559     PyObject *val = NULL;
560     return _PyHamt_Find(self->ctx_vars, key, &val);
561 }
562 
563 
564 /*[clinic input]
565 _contextvars.Context.get
566     key: object
567     default: object = None
568     /
569 
570 Return the value for `key` if `key` has the value in the context object.
571 
572 If `key` does not exist, return `default`. If `default` is not given,
573 return None.
574 [clinic start generated code]*/
575 
576 static PyObject *
_contextvars_Context_get_impl(PyContext * self,PyObject * key,PyObject * default_value)577 _contextvars_Context_get_impl(PyContext *self, PyObject *key,
578                               PyObject *default_value)
579 /*[clinic end generated code: output=0c54aa7664268189 input=c8eeb81505023995]*/
580 {
581     if (context_check_key_type(key)) {
582         return NULL;
583     }
584 
585     PyObject *val = NULL;
586     int found = _PyHamt_Find(self->ctx_vars, key, &val);
587     if (found < 0) {
588         return NULL;
589     }
590     if (found == 0) {
591         Py_INCREF(default_value);
592         return default_value;
593     }
594     Py_INCREF(val);
595     return val;
596 }
597 
598 
599 /*[clinic input]
600 _contextvars.Context.items
601 
602 Return all variables and their values in the context object.
603 
604 The result is returned as a list of 2-tuples (variable, value).
605 [clinic start generated code]*/
606 
607 static PyObject *
_contextvars_Context_items_impl(PyContext * self)608 _contextvars_Context_items_impl(PyContext *self)
609 /*[clinic end generated code: output=fa1655c8a08502af input=00db64ae379f9f42]*/
610 {
611     return _PyHamt_NewIterItems(self->ctx_vars);
612 }
613 
614 
615 /*[clinic input]
616 _contextvars.Context.keys
617 
618 Return a list of all variables in the context object.
619 [clinic start generated code]*/
620 
621 static PyObject *
_contextvars_Context_keys_impl(PyContext * self)622 _contextvars_Context_keys_impl(PyContext *self)
623 /*[clinic end generated code: output=177227c6b63ec0e2 input=114b53aebca3449c]*/
624 {
625     return _PyHamt_NewIterKeys(self->ctx_vars);
626 }
627 
628 
629 /*[clinic input]
630 _contextvars.Context.values
631 
632 Return a list of all variables' values in the context object.
633 [clinic start generated code]*/
634 
635 static PyObject *
_contextvars_Context_values_impl(PyContext * self)636 _contextvars_Context_values_impl(PyContext *self)
637 /*[clinic end generated code: output=d286dabfc8db6dde input=ce8075d04a6ea526]*/
638 {
639     return _PyHamt_NewIterValues(self->ctx_vars);
640 }
641 
642 
643 /*[clinic input]
644 _contextvars.Context.copy
645 
646 Return a shallow copy of the context object.
647 [clinic start generated code]*/
648 
649 static PyObject *
_contextvars_Context_copy_impl(PyContext * self)650 _contextvars_Context_copy_impl(PyContext *self)
651 /*[clinic end generated code: output=30ba8896c4707a15 input=ebafdbdd9c72d592]*/
652 {
653     return (PyObject *)context_new_from_vars(self->ctx_vars);
654 }
655 
656 
657 static PyObject *
context_run(PyContext * self,PyObject * const * args,Py_ssize_t nargs,PyObject * kwnames)658 context_run(PyContext *self, PyObject *const *args,
659             Py_ssize_t nargs, PyObject *kwnames)
660 {
661     PyThreadState *ts = _PyThreadState_GET();
662 
663     if (nargs < 1) {
664         _PyErr_SetString(ts, PyExc_TypeError,
665                          "run() missing 1 required positional argument");
666         return NULL;
667     }
668 
669     if (_PyContext_Enter(ts, (PyObject *)self)) {
670         return NULL;
671     }
672 
673     PyObject *call_result = _PyObject_VectorcallTstate(
674         ts, args[0], args + 1, nargs - 1, kwnames);
675 
676     if (_PyContext_Exit(ts, (PyObject *)self)) {
677         return NULL;
678     }
679 
680     return call_result;
681 }
682 
683 
684 static PyMethodDef PyContext_methods[] = {
685     _CONTEXTVARS_CONTEXT_GET_METHODDEF
686     _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
687     _CONTEXTVARS_CONTEXT_KEYS_METHODDEF
688     _CONTEXTVARS_CONTEXT_VALUES_METHODDEF
689     _CONTEXTVARS_CONTEXT_COPY_METHODDEF
690     {"run", _PyCFunction_CAST(context_run), METH_FASTCALL | METH_KEYWORDS, NULL},
691     {NULL, NULL}
692 };
693 
694 static PySequenceMethods PyContext_as_sequence = {
695     0,                                   /* sq_length */
696     0,                                   /* sq_concat */
697     0,                                   /* sq_repeat */
698     0,                                   /* sq_item */
699     0,                                   /* sq_slice */
700     0,                                   /* sq_ass_item */
701     0,                                   /* sq_ass_slice */
702     (objobjproc)context_tp_contains,     /* sq_contains */
703     0,                                   /* sq_inplace_concat */
704     0,                                   /* sq_inplace_repeat */
705 };
706 
707 static PyMappingMethods PyContext_as_mapping = {
708     (lenfunc)context_tp_len,             /* mp_length */
709     (binaryfunc)context_tp_subscript,    /* mp_subscript */
710 };
711 
712 PyTypeObject PyContext_Type = {
713     PyVarObject_HEAD_INIT(&PyType_Type, 0)
714     "_contextvars.Context",
715     sizeof(PyContext),
716     .tp_methods = PyContext_methods,
717     .tp_as_mapping = &PyContext_as_mapping,
718     .tp_as_sequence = &PyContext_as_sequence,
719     .tp_iter = (getiterfunc)context_tp_iter,
720     .tp_dealloc = (destructor)context_tp_dealloc,
721     .tp_getattro = PyObject_GenericGetAttr,
722     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
723     .tp_richcompare = context_tp_richcompare,
724     .tp_traverse = (traverseproc)context_tp_traverse,
725     .tp_clear = (inquiry)context_tp_clear,
726     .tp_new = context_tp_new,
727     .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist),
728     .tp_hash = PyObject_HashNotImplemented,
729 };
730 
731 
732 /////////////////////////// ContextVar
733 
734 
735 static int
contextvar_set(PyContextVar * var,PyObject * val)736 contextvar_set(PyContextVar *var, PyObject *val)
737 {
738     var->var_cached = NULL;
739     PyThreadState *ts = _PyThreadState_GET();
740 
741     PyContext *ctx = context_get();
742     if (ctx == NULL) {
743         return -1;
744     }
745 
746     PyHamtObject *new_vars = _PyHamt_Assoc(
747         ctx->ctx_vars, (PyObject *)var, val);
748     if (new_vars == NULL) {
749         return -1;
750     }
751 
752     Py_SETREF(ctx->ctx_vars, new_vars);
753 
754     var->var_cached = val;  /* borrow */
755     var->var_cached_tsid = ts->id;
756     var->var_cached_tsver = ts->context_ver;
757     return 0;
758 }
759 
760 static int
contextvar_del(PyContextVar * var)761 contextvar_del(PyContextVar *var)
762 {
763     var->var_cached = NULL;
764 
765     PyContext *ctx = context_get();
766     if (ctx == NULL) {
767         return -1;
768     }
769 
770     PyHamtObject *vars = ctx->ctx_vars;
771     PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var);
772     if (new_vars == NULL) {
773         return -1;
774     }
775 
776     if (vars == new_vars) {
777         Py_DECREF(new_vars);
778         PyErr_SetObject(PyExc_LookupError, (PyObject *)var);
779         return -1;
780     }
781 
782     Py_SETREF(ctx->ctx_vars, new_vars);
783     return 0;
784 }
785 
786 static Py_hash_t
contextvar_generate_hash(void * addr,PyObject * name)787 contextvar_generate_hash(void *addr, PyObject *name)
788 {
789     /* Take hash of `name` and XOR it with the object's addr.
790 
791        The structure of the tree is encoded in objects' hashes, which
792        means that sufficiently similar hashes would result in tall trees
793        with many Collision nodes.  Which would, in turn, result in slower
794        get and set operations.
795 
796        The XORing helps to ensure that:
797 
798        (1) sequentially allocated ContextVar objects have
799            different hashes;
800 
801        (2) context variables with equal names have
802            different hashes.
803     */
804 
805     Py_hash_t name_hash = PyObject_Hash(name);
806     if (name_hash == -1) {
807         return -1;
808     }
809 
810     Py_hash_t res = _Py_HashPointer(addr) ^ name_hash;
811     return res == -1 ? -2 : res;
812 }
813 
814 static PyContextVar *
contextvar_new(PyObject * name,PyObject * def)815 contextvar_new(PyObject *name, PyObject *def)
816 {
817     if (!PyUnicode_Check(name)) {
818         PyErr_SetString(PyExc_TypeError,
819                         "context variable name must be a str");
820         return NULL;
821     }
822 
823     PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type);
824     if (var == NULL) {
825         return NULL;
826     }
827 
828     var->var_hash = contextvar_generate_hash(var, name);
829     if (var->var_hash == -1) {
830         Py_DECREF(var);
831         return NULL;
832     }
833 
834     Py_INCREF(name);
835     var->var_name = name;
836 
837     Py_XINCREF(def);
838     var->var_default = def;
839 
840     var->var_cached = NULL;
841     var->var_cached_tsid = 0;
842     var->var_cached_tsver = 0;
843 
844     if (_PyObject_GC_MAY_BE_TRACKED(name) ||
845             (def != NULL && _PyObject_GC_MAY_BE_TRACKED(def)))
846     {
847         PyObject_GC_Track(var);
848     }
849     return var;
850 }
851 
852 
853 /*[clinic input]
854 class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type"
855 [clinic start generated code]*/
856 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/
857 
858 
859 static PyObject *
contextvar_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)860 contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
861 {
862     static char *kwlist[] = {"", "default", NULL};
863     PyObject *name;
864     PyObject *def = NULL;
865 
866     if (!PyArg_ParseTupleAndKeywords(
867             args, kwds, "O|$O:ContextVar", kwlist, &name, &def))
868     {
869         return NULL;
870     }
871 
872     return (PyObject *)contextvar_new(name, def);
873 }
874 
/* tp_clear for ContextVar: release the name and default references and
   reset the lookup cache fields.  var_cached is only set to NULL, with no
   decref, because it is a borrowed pointer.  Always returns 0. */
static int
contextvar_tp_clear(PyContextVar *self)
{
    Py_CLEAR(self->var_name);
    Py_CLEAR(self->var_default);
    self->var_cached = NULL;
    self->var_cached_tsid = 0;
    self->var_cached_tsver = 0;
    return 0;
}
885 
/* tp_traverse for ContextVar: report the name and the default value to the
   garbage collector.  The cached value is not visited. */
static int
contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg)
{
    Py_VISIT(self->var_name);
    Py_VISIT(self->var_default);
    return 0;
}
893 
/* tp_dealloc for ContextVar: untrack from the GC, drop owned references
   through contextvar_tp_clear(), then free the object's memory. */
static void
contextvar_tp_dealloc(PyContextVar *self)
{
    PyObject_GC_UnTrack(self);
    (void)contextvar_tp_clear(self);
    Py_TYPE(self)->tp_free(self);
}
901 
902 static Py_hash_t
contextvar_tp_hash(PyContextVar * self)903 contextvar_tp_hash(PyContextVar *self)
904 {
905     return self->var_hash;
906 }
907 
908 static PyObject *
contextvar_tp_repr(PyContextVar * self)909 contextvar_tp_repr(PyContextVar *self)
910 {
911     _PyUnicodeWriter writer;
912 
913     _PyUnicodeWriter_Init(&writer);
914 
915     if (_PyUnicodeWriter_WriteASCIIString(
916             &writer, "<ContextVar name=", 17) < 0)
917     {
918         goto error;
919     }
920 
921     PyObject *name = PyObject_Repr(self->var_name);
922     if (name == NULL) {
923         goto error;
924     }
925     if (_PyUnicodeWriter_WriteStr(&writer, name) < 0) {
926         Py_DECREF(name);
927         goto error;
928     }
929     Py_DECREF(name);
930 
931     if (self->var_default != NULL) {
932         if (_PyUnicodeWriter_WriteASCIIString(&writer, " default=", 9) < 0) {
933             goto error;
934         }
935 
936         PyObject *def = PyObject_Repr(self->var_default);
937         if (def == NULL) {
938             goto error;
939         }
940         if (_PyUnicodeWriter_WriteStr(&writer, def) < 0) {
941             Py_DECREF(def);
942             goto error;
943         }
944         Py_DECREF(def);
945     }
946 
947     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
948     if (addr == NULL) {
949         goto error;
950     }
951     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
952         Py_DECREF(addr);
953         goto error;
954     }
955     Py_DECREF(addr);
956 
957     return _PyUnicodeWriter_Finish(&writer);
958 
959 error:
960     _PyUnicodeWriter_Dealloc(&writer);
961     return NULL;
962 }
963 
964 
965 /*[clinic input]
966 _contextvars.ContextVar.get
967     default: object = NULL
968     /
969 
970 Return a value for the context variable for the current context.
971 
972 If there is no value for the variable in the current context, the method will:
973  * return the value of the default argument of the method, if provided; or
974  * return the default value for the context variable, if it was created
975    with one; or
976  * raise a LookupError.
977 [clinic start generated code]*/
978 
979 static PyObject *
_contextvars_ContextVar_get_impl(PyContextVar * self,PyObject * default_value)980 _contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value)
981 /*[clinic end generated code: output=0746bd0aa2ced7bf input=30aa2ab9e433e401]*/
982 {
983     if (!PyContextVar_CheckExact(self)) {
984         PyErr_SetString(
985             PyExc_TypeError, "an instance of ContextVar was expected");
986         return NULL;
987     }
988 
989     PyObject *val;
990     if (PyContextVar_Get((PyObject *)self, default_value, &val) < 0) {
991         return NULL;
992     }
993 
994     if (val == NULL) {
995         PyErr_SetObject(PyExc_LookupError, (PyObject *)self);
996         return NULL;
997     }
998 
999     return val;
1000 }
1001 
1002 /*[clinic input]
1003 _contextvars.ContextVar.set
1004     value: object
1005     /
1006 
1007 Call to set a new value for the context variable in the current context.
1008 
1009 The required value argument is the new value for the context variable.
1010 
1011 Returns a Token object that can be used to restore the variable to its previous
1012 value via the `ContextVar.reset()` method.
1013 [clinic start generated code]*/
1014 
1015 static PyObject *
_contextvars_ContextVar_set(PyContextVar * self,PyObject * value)1016 _contextvars_ContextVar_set(PyContextVar *self, PyObject *value)
1017 /*[clinic end generated code: output=446ed5e820d6d60b input=c0a6887154227453]*/
1018 {
1019     return PyContextVar_Set((PyObject *)self, value);
1020 }
1021 
1022 /*[clinic input]
1023 _contextvars.ContextVar.reset
1024     token: object
1025     /
1026 
1027 Reset the context variable.
1028 
1029 The variable is reset to the value it had before the `ContextVar.set()` that
1030 created the token was used.
1031 [clinic start generated code]*/
1032 
1033 static PyObject *
_contextvars_ContextVar_reset(PyContextVar * self,PyObject * token)1034 _contextvars_ContextVar_reset(PyContextVar *self, PyObject *token)
1035 /*[clinic end generated code: output=d4ee34d0742d62ee input=ebe2881e5af4ffda]*/
1036 {
1037     if (!PyContextToken_CheckExact(token)) {
1038         PyErr_Format(PyExc_TypeError,
1039                      "expected an instance of Token, got %R", token);
1040         return NULL;
1041     }
1042 
1043     if (PyContextVar_Reset((PyObject *)self, token)) {
1044         return NULL;
1045     }
1046 
1047     Py_RETURN_NONE;
1048 }
1049 
1050 
1051 static PyMemberDef PyContextVar_members[] = {
1052     {"name", T_OBJECT, offsetof(PyContextVar, var_name), READONLY},
1053     {NULL}
1054 };
1055 
1056 static PyMethodDef PyContextVar_methods[] = {
1057     _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF
1058     _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF
1059     _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF
1060     {"__class_getitem__", Py_GenericAlias,
1061     METH_O|METH_CLASS,       PyDoc_STR("See PEP 585")},
1062     {NULL, NULL}
1063 };
1064 
1065 PyTypeObject PyContextVar_Type = {
1066     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1067     "_contextvars.ContextVar",
1068     sizeof(PyContextVar),
1069     .tp_methods = PyContextVar_methods,
1070     .tp_members = PyContextVar_members,
1071     .tp_dealloc = (destructor)contextvar_tp_dealloc,
1072     .tp_getattro = PyObject_GenericGetAttr,
1073     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1074     .tp_traverse = (traverseproc)contextvar_tp_traverse,
1075     .tp_clear = (inquiry)contextvar_tp_clear,
1076     .tp_new = contextvar_tp_new,
1077     .tp_free = PyObject_GC_Del,
1078     .tp_hash = (hashfunc)contextvar_tp_hash,
1079     .tp_repr = (reprfunc)contextvar_tp_repr,
1080 };
1081 
1082 
1083 /////////////////////////// Token
1084 
1085 static PyObject * get_token_missing(void);
1086 
1087 
1088 /*[clinic input]
1089 class _contextvars.Token "PyContextToken *" "&PyContextToken_Type"
1090 [clinic start generated code]*/
1091 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/
1092 
1093 
1094 static PyObject *
token_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)1095 token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
1096 {
1097     PyErr_SetString(PyExc_RuntimeError,
1098                     "Tokens can only be created by ContextVars");
1099     return NULL;
1100 }
1101 
1102 static int
token_tp_clear(PyContextToken * self)1103 token_tp_clear(PyContextToken *self)
1104 {
1105     Py_CLEAR(self->tok_ctx);
1106     Py_CLEAR(self->tok_var);
1107     Py_CLEAR(self->tok_oldval);
1108     return 0;
1109 }
1110 
1111 static int
token_tp_traverse(PyContextToken * self,visitproc visit,void * arg)1112 token_tp_traverse(PyContextToken *self, visitproc visit, void *arg)
1113 {
1114     Py_VISIT(self->tok_ctx);
1115     Py_VISIT(self->tok_var);
1116     Py_VISIT(self->tok_oldval);
1117     return 0;
1118 }
1119 
1120 static void
token_tp_dealloc(PyContextToken * self)1121 token_tp_dealloc(PyContextToken *self)
1122 {
1123     PyObject_GC_UnTrack(self);
1124     (void)token_tp_clear(self);
1125     Py_TYPE(self)->tp_free(self);
1126 }
1127 
1128 static PyObject *
token_tp_repr(PyContextToken * self)1129 token_tp_repr(PyContextToken *self)
1130 {
1131     _PyUnicodeWriter writer;
1132 
1133     _PyUnicodeWriter_Init(&writer);
1134 
1135     if (_PyUnicodeWriter_WriteASCIIString(&writer, "<Token", 6) < 0) {
1136         goto error;
1137     }
1138 
1139     if (self->tok_used) {
1140         if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) {
1141             goto error;
1142         }
1143     }
1144 
1145     if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) {
1146         goto error;
1147     }
1148 
1149     PyObject *var = PyObject_Repr((PyObject *)self->tok_var);
1150     if (var == NULL) {
1151         goto error;
1152     }
1153     if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) {
1154         Py_DECREF(var);
1155         goto error;
1156     }
1157     Py_DECREF(var);
1158 
1159     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
1160     if (addr == NULL) {
1161         goto error;
1162     }
1163     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
1164         Py_DECREF(addr);
1165         goto error;
1166     }
1167     Py_DECREF(addr);
1168 
1169     return _PyUnicodeWriter_Finish(&writer);
1170 
1171 error:
1172     _PyUnicodeWriter_Dealloc(&writer);
1173     return NULL;
1174 }
1175 
1176 static PyObject *
token_get_var(PyContextToken * self,void * Py_UNUSED (ignored))1177 token_get_var(PyContextToken *self, void *Py_UNUSED(ignored))
1178 {
1179     Py_INCREF(self->tok_var);
1180     return (PyObject *)self->tok_var;
1181 }
1182 
1183 static PyObject *
token_get_old_value(PyContextToken * self,void * Py_UNUSED (ignored))1184 token_get_old_value(PyContextToken *self, void *Py_UNUSED(ignored))
1185 {
1186     if (self->tok_oldval == NULL) {
1187         return get_token_missing();
1188     }
1189 
1190     Py_INCREF(self->tok_oldval);
1191     return self->tok_oldval;
1192 }
1193 
1194 static PyGetSetDef PyContextTokenType_getsetlist[] = {
1195     {"var", (getter)token_get_var, NULL, NULL},
1196     {"old_value", (getter)token_get_old_value, NULL, NULL},
1197     {NULL}
1198 };
1199 
1200 static PyMethodDef PyContextTokenType_methods[] = {
1201     {"__class_getitem__",    Py_GenericAlias,
1202     METH_O|METH_CLASS,       PyDoc_STR("See PEP 585")},
1203     {NULL}
1204 };
1205 
1206 PyTypeObject PyContextToken_Type = {
1207     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1208     "_contextvars.Token",
1209     sizeof(PyContextToken),
1210     .tp_methods = PyContextTokenType_methods,
1211     .tp_getset = PyContextTokenType_getsetlist,
1212     .tp_dealloc = (destructor)token_tp_dealloc,
1213     .tp_getattro = PyObject_GenericGetAttr,
1214     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1215     .tp_traverse = (traverseproc)token_tp_traverse,
1216     .tp_clear = (inquiry)token_tp_clear,
1217     .tp_new = token_tp_new,
1218     .tp_free = PyObject_GC_Del,
1219     .tp_hash = PyObject_HashNotImplemented,
1220     .tp_repr = (reprfunc)token_tp_repr,
1221 };
1222 
1223 static PyContextToken *
token_new(PyContext * ctx,PyContextVar * var,PyObject * val)1224 token_new(PyContext *ctx, PyContextVar *var, PyObject *val)
1225 {
1226     PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type);
1227     if (tok == NULL) {
1228         return NULL;
1229     }
1230 
1231     Py_INCREF(ctx);
1232     tok->tok_ctx = ctx;
1233 
1234     Py_INCREF(var);
1235     tok->tok_var = var;
1236 
1237     Py_XINCREF(val);
1238     tok->tok_oldval = val;
1239 
1240     tok->tok_used = 0;
1241 
1242     PyObject_GC_Track(tok);
1243     return tok;
1244 }
1245 
1246 
1247 /////////////////////////// Token.MISSING
1248 
1249 
1250 static PyObject *_token_missing;
1251 
1252 
1253 typedef struct {
1254     PyObject_HEAD
1255 } PyContextTokenMissing;
1256 
1257 
1258 static PyObject *
context_token_missing_tp_repr(PyObject * self)1259 context_token_missing_tp_repr(PyObject *self)
1260 {
1261     return PyUnicode_FromString("<Token.MISSING>");
1262 }
1263 
1264 
1265 PyTypeObject _PyContextTokenMissing_Type = {
1266     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1267     "Token.MISSING",
1268     sizeof(PyContextTokenMissing),
1269     .tp_getattro = PyObject_GenericGetAttr,
1270     .tp_flags = Py_TPFLAGS_DEFAULT,
1271     .tp_repr = context_token_missing_tp_repr,
1272 };
1273 
1274 
1275 static PyObject *
get_token_missing(void)1276 get_token_missing(void)
1277 {
1278     if (_token_missing != NULL) {
1279         Py_INCREF(_token_missing);
1280         return _token_missing;
1281     }
1282 
1283     _token_missing = (PyObject *)PyObject_New(
1284         PyContextTokenMissing, &_PyContextTokenMissing_Type);
1285     if (_token_missing == NULL) {
1286         return NULL;
1287     }
1288 
1289     Py_INCREF(_token_missing);
1290     return _token_missing;
1291 }
1292 
1293 
1294 ///////////////////////////
1295 
1296 
1297 void
_PyContext_ClearFreeList(PyInterpreterState * interp)1298 _PyContext_ClearFreeList(PyInterpreterState *interp)
1299 {
1300 #if PyContext_MAXFREELIST > 0
1301     struct _Py_context_state *state = &interp->context;
1302     for (; state->numfree; state->numfree--) {
1303         PyContext *ctx = state->freelist;
1304         state->freelist = (PyContext *)ctx->ctx_weakreflist;
1305         ctx->ctx_weakreflist = NULL;
1306         PyObject_GC_Del(ctx);
1307     }
1308 #endif
1309 }
1310 
1311 
1312 void
_PyContext_Fini(PyInterpreterState * interp)1313 _PyContext_Fini(PyInterpreterState *interp)
1314 {
1315     if (_Py_IsMainInterpreter(interp)) {
1316         Py_CLEAR(_token_missing);
1317     }
1318     _PyContext_ClearFreeList(interp);
1319 #if defined(Py_DEBUG) && PyContext_MAXFREELIST > 0
1320     struct _Py_context_state *state = &interp->context;
1321     state->numfree = -1;
1322 #endif
1323     _PyHamt_Fini(interp);
1324 }
1325 
1326 
1327 PyStatus
_PyContext_Init(PyInterpreterState * interp)1328 _PyContext_Init(PyInterpreterState *interp)
1329 {
1330     if (!_Py_IsMainInterpreter(interp)) {
1331         return _PyStatus_OK();
1332     }
1333 
1334     PyObject *missing = get_token_missing();
1335     if (PyDict_SetItemString(
1336         PyContextToken_Type.tp_dict, "MISSING", missing))
1337     {
1338         Py_DECREF(missing);
1339         return _PyStatus_ERR("can't init context types");
1340     }
1341     Py_DECREF(missing);
1342 
1343     return _PyStatus_OK();
1344 }
1345