1 /* AST Optimizer */
2 #include "Python.h"
3 #include "pycore_ast.h"           // _PyAST_GetDocString()
4 #include "pycore_compile.h"       // _PyASTOptimizeState
5 #include "pycore_pystate.h"       // _PyThreadState_GET()
6 #include "pycore_format.h"        // F_LJUST
7 
8 
9 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)10 make_const(expr_ty node, PyObject *val, PyArena *arena)
11 {
12     // Even if no new value was calculated, make_const may still
13     // need to clear an error (e.g. for division by zero)
14     if (val == NULL) {
15         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
16             return 0;
17         }
18         PyErr_Clear();
19         return 1;
20     }
21     if (_PyArena_AddPyObject(arena, val) < 0) {
22         Py_DECREF(val);
23         return 0;
24     }
25     node->kind = Constant_kind;
26     node->v.Constant.kind = NULL;
27     node->v.Constant.value = val;
28     return 1;
29 }
30 
/* Overwrite the expression node *TO with a byte-wise copy of *FROM
   (sizeof(struct _expr) bytes, via memcpy). */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
32 
33 static int
has_starred(asdl_expr_seq * elts)34 has_starred(asdl_expr_seq *elts)
35 {
36     Py_ssize_t n = asdl_seq_LEN(elts);
37     for (Py_ssize_t i = 0; i < n; i++) {
38         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
39         if (e->kind == Starred_kind) {
40             return 1;
41         }
42     }
43     return 0;
44 }
45 
46 
47 static PyObject*
unary_not(PyObject * v)48 unary_not(PyObject *v)
49 {
50     int r = PyObject_IsTrue(v);
51     if (r < 0)
52         return NULL;
53     return PyBool_FromLong(!r);
54 }
55 
56 static int
fold_unaryop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)57 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
58 {
59     expr_ty arg = node->v.UnaryOp.operand;
60 
61     if (arg->kind != Constant_kind) {
62         /* Fold not into comparison */
63         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
64                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
65             /* Eq and NotEq are often implemented in terms of one another, so
66                folding not (self == other) into self != other breaks implementation
67                of !=. Detecting such cases doesn't seem worthwhile.
68                Python uses </> for 'is subset'/'is superset' operations on sets.
69                They don't satisfy not folding laws. */
70             cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
71             switch (op) {
72             case Is:
73                 op = IsNot;
74                 break;
75             case IsNot:
76                 op = Is;
77                 break;
78             case In:
79                 op = NotIn;
80                 break;
81             case NotIn:
82                 op = In;
83                 break;
84             // The remaining comparison operators can't be safely inverted
85             case Eq:
86             case NotEq:
87             case Lt:
88             case LtE:
89             case Gt:
90             case GtE:
91                 op = 0; // The AST enums leave "0" free as an "unused" marker
92                 break;
93             // No default case, so the compiler will emit a warning if new
94             // comparison operators are added without being handled here
95             }
96             if (op) {
97                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
98                 COPY_NODE(node, arg);
99                 return 1;
100             }
101         }
102         return 1;
103     }
104 
105     typedef PyObject *(*unary_op)(PyObject*);
106     static const unary_op ops[] = {
107         [Invert] = PyNumber_Invert,
108         [Not] = unary_not,
109         [UAdd] = PyNumber_Positive,
110         [USub] = PyNumber_Negative,
111     };
112     PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
113     return make_const(node, newval, arena);
114 }
115 
116 /* Check whether a collection doesn't containing too much items (including
117    subcollections).  This protects from creating a constant that needs
118    too much time for calculating a hash.
119    "limit" is the maximal number of items.
120    Returns the negative number if the total number of items exceeds the
121    limit.  Otherwise returns the limit minus the total number of items.
122 */
123 
124 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)125 check_complexity(PyObject *obj, Py_ssize_t limit)
126 {
127     if (PyTuple_Check(obj)) {
128         Py_ssize_t i;
129         limit -= PyTuple_GET_SIZE(obj);
130         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
131             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
132         }
133         return limit;
134     }
135     else if (PyFrozenSet_Check(obj)) {
136         Py_ssize_t i = 0;
137         PyObject *item;
138         Py_hash_t hash;
139         limit -= PySet_GET_SIZE(obj);
140         while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
141             limit = check_complexity(item, limit);
142         }
143     }
144     return limit;
145 }
146 
/* Limits checked by safe_multiply(), safe_power() and safe_lshift() before
   an operation is folded; an operation that would exceed them is not
   folded. */
#define MAX_INT_SIZE           128  /* bits */
#define MAX_COLLECTION_SIZE    256  /* items */
#define MAX_STR_SIZE          4096  /* characters */
#define MAX_TOTAL_ITEMS       1024  /* including nested collections */
151 
152 static PyObject *
safe_multiply(PyObject * v,PyObject * w)153 safe_multiply(PyObject *v, PyObject *w)
154 {
155     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
156         size_t vbits = _PyLong_NumBits(v);
157         size_t wbits = _PyLong_NumBits(w);
158         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
159             return NULL;
160         }
161         if (vbits + wbits > MAX_INT_SIZE) {
162             return NULL;
163         }
164     }
165     else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
166         Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
167                                              PySet_GET_SIZE(w);
168         if (size) {
169             long n = PyLong_AsLong(v);
170             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
171                 return NULL;
172             }
173             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
174                 return NULL;
175             }
176         }
177     }
178     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
179         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
180                                                PyBytes_GET_SIZE(w);
181         if (size) {
182             long n = PyLong_AsLong(v);
183             if (n < 0 || n > MAX_STR_SIZE / size) {
184                 return NULL;
185             }
186         }
187     }
188     else if (PyLong_Check(w) &&
189              (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
190               PyUnicode_Check(v) || PyBytes_Check(v)))
191     {
192         return safe_multiply(w, v);
193     }
194 
195     return PyNumber_Multiply(v, w);
196 }
197 
198 static PyObject *
safe_power(PyObject * v,PyObject * w)199 safe_power(PyObject *v, PyObject *w)
200 {
201     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
202         size_t vbits = _PyLong_NumBits(v);
203         size_t wbits = PyLong_AsSize_t(w);
204         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
205             return NULL;
206         }
207         if (vbits > MAX_INT_SIZE / wbits) {
208             return NULL;
209         }
210     }
211 
212     return PyNumber_Power(v, w, Py_None);
213 }
214 
215 static PyObject *
safe_lshift(PyObject * v,PyObject * w)216 safe_lshift(PyObject *v, PyObject *w)
217 {
218     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
219         size_t vbits = _PyLong_NumBits(v);
220         size_t wbits = PyLong_AsSize_t(w);
221         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
222             return NULL;
223         }
224         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
225             return NULL;
226         }
227     }
228 
229     return PyNumber_Lshift(v, w);
230 }
231 
232 static PyObject *
safe_mod(PyObject * v,PyObject * w)233 safe_mod(PyObject *v, PyObject *w)
234 {
235     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
236         return NULL;
237     }
238 
239     return PyNumber_Remainder(v, w);
240 }
241 
242 
243 static expr_ty
parse_literal(PyObject * fmt,Py_ssize_t * ppos,PyArena * arena)244 parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
245 {
246     const void *data = PyUnicode_DATA(fmt);
247     int kind = PyUnicode_KIND(fmt);
248     Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
249     Py_ssize_t start, pos;
250     int has_percents = 0;
251     start = pos = *ppos;
252     while (pos < size) {
253         if (PyUnicode_READ(kind, data, pos) != '%') {
254             pos++;
255         }
256         else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
257             has_percents = 1;
258             pos += 2;
259         }
260         else {
261             break;
262         }
263     }
264     *ppos = pos;
265     if (pos == start) {
266         return NULL;
267     }
268     PyObject *str = PyUnicode_Substring(fmt, start, pos);
269     /* str = str.replace('%%', '%') */
270     if (str && has_percents) {
271         _Py_DECLARE_STR(percent, "%");
272         _Py_DECLARE_STR(dbl_percent, "%%");
273         Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
274                                          &_Py_STR(percent), -1));
275     }
276     if (!str) {
277         return NULL;
278     }
279 
280     if (_PyArena_AddPyObject(arena, str) < 0) {
281         Py_DECREF(str);
282         return NULL;
283     }
284     return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
285 }
286 
/* Bound on the number of width/precision digits handled by
   simple_format_arg_parse(), which gives up once the digit count reaches
   this value; also used to size the format-spec buffer in parse_format(). */
#define MAXDIGITS 3
288 
289 static int
simple_format_arg_parse(PyObject * fmt,Py_ssize_t * ppos,int * spec,int * flags,int * width,int * prec)290 simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
291                         int *spec, int *flags, int *width, int *prec)
292 {
293     Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
294     Py_UCS4 ch;
295 
296 #define NEXTC do {                      \
297     if (pos >= len) {                   \
298         return 0;                       \
299     }                                   \
300     ch = PyUnicode_READ_CHAR(fmt, pos); \
301     pos++;                              \
302 } while (0)
303 
304     *flags = 0;
305     while (1) {
306         NEXTC;
307         switch (ch) {
308             case '-': *flags |= F_LJUST; continue;
309             case '+': *flags |= F_SIGN; continue;
310             case ' ': *flags |= F_BLANK; continue;
311             case '#': *flags |= F_ALT; continue;
312             case '0': *flags |= F_ZERO; continue;
313         }
314         break;
315     }
316     if ('0' <= ch && ch <= '9') {
317         *width = 0;
318         int digits = 0;
319         while ('0' <= ch && ch <= '9') {
320             *width = *width * 10 + (ch - '0');
321             NEXTC;
322             if (++digits >= MAXDIGITS) {
323                 return 0;
324             }
325         }
326     }
327 
328     if (ch == '.') {
329         NEXTC;
330         *prec = 0;
331         if ('0' <= ch && ch <= '9') {
332             int digits = 0;
333             while ('0' <= ch && ch <= '9') {
334                 *prec = *prec * 10 + (ch - '0');
335                 NEXTC;
336                 if (++digits >= MAXDIGITS) {
337                     return 0;
338                 }
339             }
340         }
341     }
342     *spec = ch;
343     *ppos = pos;
344     return 1;
345 
346 #undef NEXTC
347 }
348 
349 static expr_ty
parse_format(PyObject * fmt,Py_ssize_t * ppos,expr_ty arg,PyArena * arena)350 parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
351 {
352     int spec, flags, width = -1, prec = -1;
353     if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
354         // Unsupported format.
355         return NULL;
356     }
357     if (spec == 's' || spec == 'r' || spec == 'a') {
358         char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
359         if (!(flags & F_LJUST) && width > 0) {
360             *p++ = '>';
361         }
362         if (width >= 0) {
363             p += snprintf(p, MAXDIGITS + 1, "%d", width);
364         }
365         if (prec >= 0) {
366             p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
367         }
368         expr_ty format_spec = NULL;
369         if (p != buf) {
370             PyObject *str = PyUnicode_FromString(buf);
371             if (str == NULL) {
372                 return NULL;
373             }
374             if (_PyArena_AddPyObject(arena, str) < 0) {
375                 Py_DECREF(str);
376                 return NULL;
377             }
378             format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
379             if (format_spec == NULL) {
380                 return NULL;
381             }
382         }
383         return _PyAST_FormattedValue(arg, spec, format_spec,
384                                      arg->lineno, arg->col_offset,
385                                      arg->end_lineno, arg->end_col_offset,
386                                      arena);
387     }
388     // Unsupported format.
389     return NULL;
390 }
391 
392 static int
optimize_format(expr_ty node,PyObject * fmt,asdl_expr_seq * elts,PyArena * arena)393 optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
394 {
395     Py_ssize_t pos = 0;
396     Py_ssize_t cnt = 0;
397     asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
398     if (!seq) {
399         return 0;
400     }
401     seq->size = 0;
402 
403     while (1) {
404         expr_ty lit = parse_literal(fmt, &pos, arena);
405         if (lit) {
406             asdl_seq_SET(seq, seq->size++, lit);
407         }
408         else if (PyErr_Occurred()) {
409             return 0;
410         }
411 
412         if (pos >= PyUnicode_GET_LENGTH(fmt)) {
413             break;
414         }
415         if (cnt >= asdl_seq_LEN(elts)) {
416             // More format units than items.
417             return 1;
418         }
419         assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
420         pos++;
421         expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
422         cnt++;
423         if (!expr) {
424             return !PyErr_Occurred();
425         }
426         asdl_seq_SET(seq, seq->size++, expr);
427     }
428     if (cnt < asdl_seq_LEN(elts)) {
429         // More items than format units.
430         return 1;
431     }
432     expr_ty res = _PyAST_JoinedStr(seq,
433                                    node->lineno, node->col_offset,
434                                    node->end_lineno, node->end_col_offset,
435                                    arena);
436     if (!res) {
437         return 0;
438     }
439     COPY_NODE(node, res);
440 //     PySys_FormatStderr("format = %R\n", fmt);
441     return 1;
442 }
443 
444 static int
fold_binop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)445 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
446 {
447     expr_ty lhs, rhs;
448     lhs = node->v.BinOp.left;
449     rhs = node->v.BinOp.right;
450     if (lhs->kind != Constant_kind) {
451         return 1;
452     }
453     PyObject *lv = lhs->v.Constant.value;
454 
455     if (node->v.BinOp.op == Mod &&
456         rhs->kind == Tuple_kind &&
457         PyUnicode_Check(lv) &&
458         !has_starred(rhs->v.Tuple.elts))
459     {
460         return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
461     }
462 
463     if (rhs->kind != Constant_kind) {
464         return 1;
465     }
466 
467     PyObject *rv = rhs->v.Constant.value;
468     PyObject *newval = NULL;
469 
470     switch (node->v.BinOp.op) {
471     case Add:
472         newval = PyNumber_Add(lv, rv);
473         break;
474     case Sub:
475         newval = PyNumber_Subtract(lv, rv);
476         break;
477     case Mult:
478         newval = safe_multiply(lv, rv);
479         break;
480     case Div:
481         newval = PyNumber_TrueDivide(lv, rv);
482         break;
483     case FloorDiv:
484         newval = PyNumber_FloorDivide(lv, rv);
485         break;
486     case Mod:
487         newval = safe_mod(lv, rv);
488         break;
489     case Pow:
490         newval = safe_power(lv, rv);
491         break;
492     case LShift:
493         newval = safe_lshift(lv, rv);
494         break;
495     case RShift:
496         newval = PyNumber_Rshift(lv, rv);
497         break;
498     case BitOr:
499         newval = PyNumber_Or(lv, rv);
500         break;
501     case BitXor:
502         newval = PyNumber_Xor(lv, rv);
503         break;
504     case BitAnd:
505         newval = PyNumber_And(lv, rv);
506         break;
507     // No builtin constants implement the following operators
508     case MatMult:
509         return 1;
510     // No default case, so the compiler will emit a warning if new binary
511     // operators are added without being handled here
512     }
513 
514     return make_const(node, newval, arena);
515 }
516 
517 static PyObject*
make_const_tuple(asdl_expr_seq * elts)518 make_const_tuple(asdl_expr_seq *elts)
519 {
520     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
521         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
522         if (e->kind != Constant_kind) {
523             return NULL;
524         }
525     }
526 
527     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
528     if (newval == NULL) {
529         return NULL;
530     }
531 
532     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
533         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
534         PyObject *v = e->v.Constant.value;
535         Py_INCREF(v);
536         PyTuple_SET_ITEM(newval, i, v);
537     }
538     return newval;
539 }
540 
541 static int
fold_tuple(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)542 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
543 {
544     PyObject *newval;
545 
546     if (node->v.Tuple.ctx != Load)
547         return 1;
548 
549     newval = make_const_tuple(node->v.Tuple.elts);
550     return make_const(node, newval, arena);
551 }
552 
553 static int
fold_subscr(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)554 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
555 {
556     PyObject *newval;
557     expr_ty arg, idx;
558 
559     arg = node->v.Subscript.value;
560     idx = node->v.Subscript.slice;
561     if (node->v.Subscript.ctx != Load ||
562             arg->kind != Constant_kind ||
563             idx->kind != Constant_kind)
564     {
565         return 1;
566     }
567 
568     newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
569     return make_const(node, newval, arena);
570 }
571 
572 /* Change literal list or set of constants into constant
573    tuple or frozenset respectively.  Change literal list of
574    non-constants into tuple.
575    Used for right operand of "in" and "not in" tests and for iterable
576    in "for" loop and comprehensions.
577 */
578 static int
fold_iter(expr_ty arg,PyArena * arena,_PyASTOptimizeState * state)579 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
580 {
581     PyObject *newval;
582     if (arg->kind == List_kind) {
583         /* First change a list into tuple. */
584         asdl_expr_seq *elts = arg->v.List.elts;
585         if (has_starred(elts)) {
586             return 1;
587         }
588         expr_context_ty ctx = arg->v.List.ctx;
589         arg->kind = Tuple_kind;
590         arg->v.Tuple.elts = elts;
591         arg->v.Tuple.ctx = ctx;
592         /* Try to create a constant tuple. */
593         newval = make_const_tuple(elts);
594     }
595     else if (arg->kind == Set_kind) {
596         newval = make_const_tuple(arg->v.Set.elts);
597         if (newval) {
598             Py_SETREF(newval, PyFrozenSet_New(newval));
599         }
600     }
601     else {
602         return 1;
603     }
604     return make_const(arg, newval, arena);
605 }
606 
607 static int
fold_compare(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)608 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
609 {
610     asdl_int_seq *ops;
611     asdl_expr_seq *args;
612     Py_ssize_t i;
613 
614     ops = node->v.Compare.ops;
615     args = node->v.Compare.comparators;
616     /* Change literal list or set in 'in' or 'not in' into
617        tuple or frozenset respectively. */
618     i = asdl_seq_LEN(ops) - 1;
619     int op = asdl_seq_GET(ops, i);
620     if (op == In || op == NotIn) {
621         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
622             return 0;
623         }
624     }
625     return 1;
626 }
627 
/* Forward declarations of the astfold_* traversal functions, which call
   one another.  Each takes the node to process, the arena and the
   optimizer state, and returns an int. */
static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
639 
/* Traversal helpers.  Each expands inside a function that has `ctx_` and
   `state` in scope and makes that function return 0 as soon as FUNC
   returns 0.  The TYPE argument is not used by CALL and CALL_OPT;
   CALL_SEQ pastes it into the sequence type (asdl_TYPE_seq) and the
   element type (TYPE_ty). */

/* Apply FUNC to ARG unconditionally. */
#define CALL(FUNC, TYPE, ARG) \
    if (!FUNC((ARG), ctx_, state)) \
        return 0;

/* Apply FUNC to ARG only when ARG is not NULL. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
        return 0;

/* Apply FUNC to every non-NULL element of the sequence ARG, in order. */
#define CALL_SEQ(FUNC, TYPE, ARG) { \
    int i; \
    asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
    for (i = 0; i < asdl_seq_LEN(seq); i++) { \
        TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
        if (elt != NULL && !FUNC(elt, ctx_, state)) \
            return 0; \
    } \
}
657 
658 
659 static int
astfold_body(asdl_stmt_seq * stmts,PyArena * ctx_,_PyASTOptimizeState * state)660 astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
661 {
662     int docstring = _PyAST_GetDocString(stmts) != NULL;
663     CALL_SEQ(astfold_stmt, stmt, stmts);
664     if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
665         stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
666         asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
667         if (!values) {
668             return 0;
669         }
670         asdl_seq_SET(values, 0, st->v.Expr.value);
671         expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
672                                         st->end_lineno, st->end_col_offset,
673                                         ctx_);
674         if (!expr) {
675             return 0;
676         }
677         st->v.Expr.value = expr;
678     }
679     return 1;
680 }
681 
682 static int
astfold_mod(mod_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)683 astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
684 {
685     switch (node_->kind) {
686     case Module_kind:
687         CALL(astfold_body, asdl_seq, node_->v.Module.body);
688         break;
689     case Interactive_kind:
690         CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
691         break;
692     case Expression_kind:
693         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
694         break;
695     // The following top level nodes don't participate in constant folding
696     case FunctionType_kind:
697         break;
698     // No default case, so the compiler will emit a warning if new top level
699     // compilation nodes are added without being handled here
700     }
701     return 1;
702 }
703 
704 static int
astfold_expr(expr_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)705 astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
706 {
707     if (++state->recursion_depth > state->recursion_limit) {
708         PyErr_SetString(PyExc_RecursionError,
709                         "maximum recursion depth exceeded during compilation");
710         return 0;
711     }
712     switch (node_->kind) {
713     case BoolOp_kind:
714         CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
715         break;
716     case BinOp_kind:
717         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
718         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
719         CALL(fold_binop, expr_ty, node_);
720         break;
721     case UnaryOp_kind:
722         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
723         CALL(fold_unaryop, expr_ty, node_);
724         break;
725     case Lambda_kind:
726         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
727         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
728         break;
729     case IfExp_kind:
730         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
731         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
732         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
733         break;
734     case Dict_kind:
735         CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
736         CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
737         break;
738     case Set_kind:
739         CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
740         break;
741     case ListComp_kind:
742         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
743         CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
744         break;
745     case SetComp_kind:
746         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
747         CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
748         break;
749     case DictComp_kind:
750         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
751         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
752         CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
753         break;
754     case GeneratorExp_kind:
755         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
756         CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
757         break;
758     case Await_kind:
759         CALL(astfold_expr, expr_ty, node_->v.Await.value);
760         break;
761     case Yield_kind:
762         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
763         break;
764     case YieldFrom_kind:
765         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
766         break;
767     case Compare_kind:
768         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
769         CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
770         CALL(fold_compare, expr_ty, node_);
771         break;
772     case Call_kind:
773         CALL(astfold_expr, expr_ty, node_->v.Call.func);
774         CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
775         CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
776         break;
777     case FormattedValue_kind:
778         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
779         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
780         break;
781     case JoinedStr_kind:
782         CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
783         break;
784     case Attribute_kind:
785         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
786         break;
787     case Subscript_kind:
788         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
789         CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
790         CALL(fold_subscr, expr_ty, node_);
791         break;
792     case Starred_kind:
793         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
794         break;
795     case Slice_kind:
796         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
797         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
798         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
799         break;
800     case List_kind:
801         CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
802         break;
803     case Tuple_kind:
804         CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
805         CALL(fold_tuple, expr_ty, node_);
806         break;
807     case Name_kind:
808         if (node_->v.Name.ctx == Load &&
809                 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
810             state->recursion_depth--;
811             return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
812         }
813         break;
814     case NamedExpr_kind:
815         CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
816         break;
817     case Constant_kind:
818         // Already a constant, nothing further to do
819         break;
820     // No default case, so the compiler will emit a warning if new expression
821     // kinds are added without being handled here
822     }
823     state->recursion_depth--;
824     return 1;
825 }
826 
827 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)828 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
829 {
830     CALL(astfold_expr, expr_ty, node_->value);
831     return 1;
832 }
833 
834 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)835 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
836 {
837     CALL(astfold_expr, expr_ty, node_->target);
838     CALL(astfold_expr, expr_ty, node_->iter);
839     CALL_SEQ(astfold_expr, expr, node_->ifs);
840 
841     CALL(fold_iter, expr_ty, node_->iter);
842     return 1;
843 }
844 
845 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)846 astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
847 {
848     CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
849     CALL_SEQ(astfold_arg, arg, node_->args);
850     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
851     CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
852     CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
853     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
854     CALL_SEQ(astfold_expr, expr, node_->defaults);
855     return 1;
856 }
857 
858 static int
astfold_arg(arg_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)859 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
860 {
861     if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
862         CALL_OPT(astfold_expr, expr_ty, node_->annotation);
863     }
864     return 1;
865 }
866 
/* Fold constants in a single statement node.
 *
 * Dispatches on node_->kind and hands every child expression, statement
 * list or helper node (arguments, keywords, with-items, except handlers,
 * match cases) to the matching astfold_* routine.
 *
 * state->recursion_depth is incremented on entry; if it then exceeds
 * state->recursion_limit a RecursionError is set and 0 is returned.  The
 * counter is decremented again just before the normal "return 1".
 *
 * NOTE(review): the CALL* macros are defined earlier in the file; they
 * presumably return 0 directly when a callee fails, in which case the
 * counter is not restored on that path -- confirm against the macros.
 */
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        // The return annotation is only folded when CO_FUTURE_ANNOTATIONS
        // is not set in state->ff_features.
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        // Same annotation rule as for FunctionDef.
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        // The annotation is skipped when CO_FUTURE_ANNOTATIONS is set; the
        // target and the value are visited either way.
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        // fold_iter() is applied to the iterable only after the target,
        // the iterable, the body and the orelse block have been folded.
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        // Unlike For, fold_iter() is not applied to the iterable here.
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case TryStar_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    // Undo the increment made on entry before reporting success.
    state->recursion_depth--;
    return 1;
}
993 
/* Fold constants in one exception handler node.
 *
 * For an ExceptHandler node the "type" expression is handed to
 * astfold_expr() through CALL_OPT and every statement of the handler body
 * is handed to astfold_stmt().
 *
 * Returns 1 when the end of the function is reached.
 */
static int
astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
    case ExceptHandler_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
        CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
        break;
    // No default case, so the compiler will emit a warning if new handler
    // kinds are added without being handled here
    }
    return 1;
}
1007 
/* Fold constants in one with-item node.
 *
 * The context expression is always handed to astfold_expr(); optional_vars
 * goes through CALL_OPT.
 * NOTE(review): presumably because optional_vars may be NULL -- confirm
 * against the macro definition earlier in the file.
 *
 * Returns 1 when the end of the function is reached.
 */
static int
astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    CALL(astfold_expr, expr_ty, node_->context_expr);
    CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
    return 1;
}
1015 
1016 static int
astfold_pattern(pattern_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1017 astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1018 {
1019     // Currently, this is really only used to form complex/negative numeric
1020     // constants in MatchValue and MatchMapping nodes
1021     // We still recurse into all subexpressions and subpatterns anyway
1022     if (++state->recursion_depth > state->recursion_limit) {
1023         PyErr_SetString(PyExc_RecursionError,
1024                         "maximum recursion depth exceeded during compilation");
1025         return 0;
1026     }
1027     switch (node_->kind) {
1028         case MatchValue_kind:
1029             CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
1030             break;
1031         case MatchSingleton_kind:
1032             break;
1033         case MatchSequence_kind:
1034             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
1035             break;
1036         case MatchMapping_kind:
1037             CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
1038             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
1039             break;
1040         case MatchClass_kind:
1041             CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
1042             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
1043             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
1044             break;
1045         case MatchStar_kind:
1046             break;
1047         case MatchAs_kind:
1048             if (node_->v.MatchAs.pattern) {
1049                 CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
1050             }
1051             break;
1052         case MatchOr_kind:
1053             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
1054             break;
1055     // No default case, so the compiler will emit a warning if new pattern
1056     // kinds are added without being handled here
1057     }
1058     state->recursion_depth--;
1059     return 1;
1060 }
1061 
1062 static int
astfold_match_case(match_case_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1063 astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1064 {
1065     CALL(astfold_pattern, expr_ty, node_->pattern);
1066     CALL_OPT(astfold_expr, expr_ty, node_->guard);
1067     CALL_SEQ(astfold_stmt, stmt, node_->body);
1068     return 1;
1069 }
1070 
// The traversal helper macros are not used past this point.
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ

/* Factor by which _PyAST_Optimize() multiplies the thread's current
 * recursion depth and the recursion limit when it initialises the
 * optimizer's own counters.  See comments in symtable.c. */
#define COMPILER_STACK_FRAME_SCALE 3
1077 
1078 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,_PyASTOptimizeState * state)1079 _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
1080 {
1081     PyThreadState *tstate;
1082     int recursion_limit = Py_GetRecursionLimit();
1083     int starting_recursion_depth;
1084 
1085     /* Setup recursion depth check counters */
1086     tstate = _PyThreadState_GET();
1087     if (!tstate) {
1088         return 0;
1089     }
1090     /* Be careful here to prevent overflow. */
1091     int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
1092     starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1093         recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
1094     state->recursion_depth = starting_recursion_depth;
1095     state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1096         recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
1097 
1098     int ret = astfold_mod(mod, arena, state);
1099     assert(ret || PyErr_Occurred());
1100 
1101     /* Check that the recursion depth counting balanced correctly */
1102     if (ret && state->recursion_depth != starting_recursion_depth) {
1103         PyErr_Format(PyExc_SystemError,
1104             "AST optimizer recursion depth mismatch (before=%d, after=%d)",
1105             starting_recursion_depth, state->recursion_depth);
1106         return 0;
1107     }
1108 
1109     return ret;
1110 }
1111