1 /* AST Optimizer */
2 #include "Python.h"
3 #include "pycore_ast.h" // _PyAST_GetDocString()
4 #include "pycore_compile.h" // _PyASTOptimizeState
5 #include "pycore_pystate.h" // _PyThreadState_GET()
6 #include "pycore_format.h" // F_LJUST
7
8
9 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)10 make_const(expr_ty node, PyObject *val, PyArena *arena)
11 {
12 // Even if no new value was calculated, make_const may still
13 // need to clear an error (e.g. for division by zero)
14 if (val == NULL) {
15 if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
16 return 0;
17 }
18 PyErr_Clear();
19 return 1;
20 }
21 if (_PyArena_AddPyObject(arena, val) < 0) {
22 Py_DECREF(val);
23 return 0;
24 }
25 node->kind = Constant_kind;
26 node->v.Constant.kind = NULL;
27 node->v.Constant.value = val;
28 return 1;
29 }
30
/* Overwrite the expression node TO with a byte-wise copy of the expression
   node FROM (both are pointers to struct _expr). */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
32
33 static int
has_starred(asdl_expr_seq * elts)34 has_starred(asdl_expr_seq *elts)
35 {
36 Py_ssize_t n = asdl_seq_LEN(elts);
37 for (Py_ssize_t i = 0; i < n; i++) {
38 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
39 if (e->kind == Starred_kind) {
40 return 1;
41 }
42 }
43 return 0;
44 }
45
46
47 static PyObject*
unary_not(PyObject * v)48 unary_not(PyObject *v)
49 {
50 int r = PyObject_IsTrue(v);
51 if (r < 0)
52 return NULL;
53 return PyBool_FromLong(!r);
54 }
55
56 static int
fold_unaryop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)57 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
58 {
59 expr_ty arg = node->v.UnaryOp.operand;
60
61 if (arg->kind != Constant_kind) {
62 /* Fold not into comparison */
63 if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
64 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
65 /* Eq and NotEq are often implemented in terms of one another, so
66 folding not (self == other) into self != other breaks implementation
67 of !=. Detecting such cases doesn't seem worthwhile.
68 Python uses </> for 'is subset'/'is superset' operations on sets.
69 They don't satisfy not folding laws. */
70 cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
71 switch (op) {
72 case Is:
73 op = IsNot;
74 break;
75 case IsNot:
76 op = Is;
77 break;
78 case In:
79 op = NotIn;
80 break;
81 case NotIn:
82 op = In;
83 break;
84 // The remaining comparison operators can't be safely inverted
85 case Eq:
86 case NotEq:
87 case Lt:
88 case LtE:
89 case Gt:
90 case GtE:
91 op = 0; // The AST enums leave "0" free as an "unused" marker
92 break;
93 // No default case, so the compiler will emit a warning if new
94 // comparison operators are added without being handled here
95 }
96 if (op) {
97 asdl_seq_SET(arg->v.Compare.ops, 0, op);
98 COPY_NODE(node, arg);
99 return 1;
100 }
101 }
102 return 1;
103 }
104
105 typedef PyObject *(*unary_op)(PyObject*);
106 static const unary_op ops[] = {
107 [Invert] = PyNumber_Invert,
108 [Not] = unary_not,
109 [UAdd] = PyNumber_Positive,
110 [USub] = PyNumber_Negative,
111 };
112 PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
113 return make_const(node, newval, arena);
114 }
115
116 /* Check whether a collection doesn't containing too much items (including
117 subcollections). This protects from creating a constant that needs
118 too much time for calculating a hash.
119 "limit" is the maximal number of items.
120 Returns the negative number if the total number of items exceeds the
121 limit. Otherwise returns the limit minus the total number of items.
122 */
123
124 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)125 check_complexity(PyObject *obj, Py_ssize_t limit)
126 {
127 if (PyTuple_Check(obj)) {
128 Py_ssize_t i;
129 limit -= PyTuple_GET_SIZE(obj);
130 for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
131 limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
132 }
133 return limit;
134 }
135 else if (PyFrozenSet_Check(obj)) {
136 Py_ssize_t i = 0;
137 PyObject *item;
138 Py_hash_t hash;
139 limit -= PySet_GET_SIZE(obj);
140 while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
141 limit = check_complexity(item, limit);
142 }
143 }
144 return limit;
145 }
146
/* Upper bounds on the size of values that the safe_* helpers below are
   willing to compute at compile time. */
#define MAX_INT_SIZE 128 /* bits */
#define MAX_COLLECTION_SIZE 256 /* items */
#define MAX_STR_SIZE 4096 /* characters */
#define MAX_TOTAL_ITEMS 1024 /* including nested collections */
151
152 static PyObject *
safe_multiply(PyObject * v,PyObject * w)153 safe_multiply(PyObject *v, PyObject *w)
154 {
155 if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
156 size_t vbits = _PyLong_NumBits(v);
157 size_t wbits = _PyLong_NumBits(w);
158 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
159 return NULL;
160 }
161 if (vbits + wbits > MAX_INT_SIZE) {
162 return NULL;
163 }
164 }
165 else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
166 Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
167 PySet_GET_SIZE(w);
168 if (size) {
169 long n = PyLong_AsLong(v);
170 if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
171 return NULL;
172 }
173 if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
174 return NULL;
175 }
176 }
177 }
178 else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
179 Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
180 PyBytes_GET_SIZE(w);
181 if (size) {
182 long n = PyLong_AsLong(v);
183 if (n < 0 || n > MAX_STR_SIZE / size) {
184 return NULL;
185 }
186 }
187 }
188 else if (PyLong_Check(w) &&
189 (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
190 PyUnicode_Check(v) || PyBytes_Check(v)))
191 {
192 return safe_multiply(w, v);
193 }
194
195 return PyNumber_Multiply(v, w);
196 }
197
198 static PyObject *
safe_power(PyObject * v,PyObject * w)199 safe_power(PyObject *v, PyObject *w)
200 {
201 if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
202 size_t vbits = _PyLong_NumBits(v);
203 size_t wbits = PyLong_AsSize_t(w);
204 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
205 return NULL;
206 }
207 if (vbits > MAX_INT_SIZE / wbits) {
208 return NULL;
209 }
210 }
211
212 return PyNumber_Power(v, w, Py_None);
213 }
214
215 static PyObject *
safe_lshift(PyObject * v,PyObject * w)216 safe_lshift(PyObject *v, PyObject *w)
217 {
218 if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
219 size_t vbits = _PyLong_NumBits(v);
220 size_t wbits = PyLong_AsSize_t(w);
221 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
222 return NULL;
223 }
224 if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
225 return NULL;
226 }
227 }
228
229 return PyNumber_Lshift(v, w);
230 }
231
232 static PyObject *
safe_mod(PyObject * v,PyObject * w)233 safe_mod(PyObject *v, PyObject *w)
234 {
235 if (PyUnicode_Check(v) || PyBytes_Check(v)) {
236 return NULL;
237 }
238
239 return PyNumber_Remainder(v, w);
240 }
241
242
243 static expr_ty
parse_literal(PyObject * fmt,Py_ssize_t * ppos,PyArena * arena)244 parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
245 {
246 const void *data = PyUnicode_DATA(fmt);
247 int kind = PyUnicode_KIND(fmt);
248 Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
249 Py_ssize_t start, pos;
250 int has_percents = 0;
251 start = pos = *ppos;
252 while (pos < size) {
253 if (PyUnicode_READ(kind, data, pos) != '%') {
254 pos++;
255 }
256 else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
257 has_percents = 1;
258 pos += 2;
259 }
260 else {
261 break;
262 }
263 }
264 *ppos = pos;
265 if (pos == start) {
266 return NULL;
267 }
268 PyObject *str = PyUnicode_Substring(fmt, start, pos);
269 /* str = str.replace('%%', '%') */
270 if (str && has_percents) {
271 _Py_DECLARE_STR(percent, "%");
272 _Py_DECLARE_STR(dbl_percent, "%%");
273 Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
274 &_Py_STR(percent), -1));
275 }
276 if (!str) {
277 return NULL;
278 }
279
280 if (_PyArena_AddPyObject(arena, str) < 0) {
281 Py_DECREF(str);
282 return NULL;
283 }
284 return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
285 }
286
/* Digit limit for a width or precision in a format unit:
   simple_format_arg_parse() gives up once it has consumed this many digits,
   and parse_format() sizes its format-spec buffer from it. */
#define MAXDIGITS 3
288
289 static int
simple_format_arg_parse(PyObject * fmt,Py_ssize_t * ppos,int * spec,int * flags,int * width,int * prec)290 simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
291 int *spec, int *flags, int *width, int *prec)
292 {
293 Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
294 Py_UCS4 ch;
295
296 #define NEXTC do { \
297 if (pos >= len) { \
298 return 0; \
299 } \
300 ch = PyUnicode_READ_CHAR(fmt, pos); \
301 pos++; \
302 } while (0)
303
304 *flags = 0;
305 while (1) {
306 NEXTC;
307 switch (ch) {
308 case '-': *flags |= F_LJUST; continue;
309 case '+': *flags |= F_SIGN; continue;
310 case ' ': *flags |= F_BLANK; continue;
311 case '#': *flags |= F_ALT; continue;
312 case '0': *flags |= F_ZERO; continue;
313 }
314 break;
315 }
316 if ('0' <= ch && ch <= '9') {
317 *width = 0;
318 int digits = 0;
319 while ('0' <= ch && ch <= '9') {
320 *width = *width * 10 + (ch - '0');
321 NEXTC;
322 if (++digits >= MAXDIGITS) {
323 return 0;
324 }
325 }
326 }
327
328 if (ch == '.') {
329 NEXTC;
330 *prec = 0;
331 if ('0' <= ch && ch <= '9') {
332 int digits = 0;
333 while ('0' <= ch && ch <= '9') {
334 *prec = *prec * 10 + (ch - '0');
335 NEXTC;
336 if (++digits >= MAXDIGITS) {
337 return 0;
338 }
339 }
340 }
341 }
342 *spec = ch;
343 *ppos = pos;
344 return 1;
345
346 #undef NEXTC
347 }
348
349 static expr_ty
parse_format(PyObject * fmt,Py_ssize_t * ppos,expr_ty arg,PyArena * arena)350 parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
351 {
352 int spec, flags, width = -1, prec = -1;
353 if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
354 // Unsupported format.
355 return NULL;
356 }
357 if (spec == 's' || spec == 'r' || spec == 'a') {
358 char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
359 if (!(flags & F_LJUST) && width > 0) {
360 *p++ = '>';
361 }
362 if (width >= 0) {
363 p += snprintf(p, MAXDIGITS + 1, "%d", width);
364 }
365 if (prec >= 0) {
366 p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
367 }
368 expr_ty format_spec = NULL;
369 if (p != buf) {
370 PyObject *str = PyUnicode_FromString(buf);
371 if (str == NULL) {
372 return NULL;
373 }
374 if (_PyArena_AddPyObject(arena, str) < 0) {
375 Py_DECREF(str);
376 return NULL;
377 }
378 format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
379 if (format_spec == NULL) {
380 return NULL;
381 }
382 }
383 return _PyAST_FormattedValue(arg, spec, format_spec,
384 arg->lineno, arg->col_offset,
385 arg->end_lineno, arg->end_col_offset,
386 arena);
387 }
388 // Unsupported format.
389 return NULL;
390 }
391
392 static int
optimize_format(expr_ty node,PyObject * fmt,asdl_expr_seq * elts,PyArena * arena)393 optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
394 {
395 Py_ssize_t pos = 0;
396 Py_ssize_t cnt = 0;
397 asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
398 if (!seq) {
399 return 0;
400 }
401 seq->size = 0;
402
403 while (1) {
404 expr_ty lit = parse_literal(fmt, &pos, arena);
405 if (lit) {
406 asdl_seq_SET(seq, seq->size++, lit);
407 }
408 else if (PyErr_Occurred()) {
409 return 0;
410 }
411
412 if (pos >= PyUnicode_GET_LENGTH(fmt)) {
413 break;
414 }
415 if (cnt >= asdl_seq_LEN(elts)) {
416 // More format units than items.
417 return 1;
418 }
419 assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
420 pos++;
421 expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
422 cnt++;
423 if (!expr) {
424 return !PyErr_Occurred();
425 }
426 asdl_seq_SET(seq, seq->size++, expr);
427 }
428 if (cnt < asdl_seq_LEN(elts)) {
429 // More items than format units.
430 return 1;
431 }
432 expr_ty res = _PyAST_JoinedStr(seq,
433 node->lineno, node->col_offset,
434 node->end_lineno, node->end_col_offset,
435 arena);
436 if (!res) {
437 return 0;
438 }
439 COPY_NODE(node, res);
440 // PySys_FormatStderr("format = %R\n", fmt);
441 return 1;
442 }
443
444 static int
fold_binop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)445 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
446 {
447 expr_ty lhs, rhs;
448 lhs = node->v.BinOp.left;
449 rhs = node->v.BinOp.right;
450 if (lhs->kind != Constant_kind) {
451 return 1;
452 }
453 PyObject *lv = lhs->v.Constant.value;
454
455 if (node->v.BinOp.op == Mod &&
456 rhs->kind == Tuple_kind &&
457 PyUnicode_Check(lv) &&
458 !has_starred(rhs->v.Tuple.elts))
459 {
460 return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
461 }
462
463 if (rhs->kind != Constant_kind) {
464 return 1;
465 }
466
467 PyObject *rv = rhs->v.Constant.value;
468 PyObject *newval = NULL;
469
470 switch (node->v.BinOp.op) {
471 case Add:
472 newval = PyNumber_Add(lv, rv);
473 break;
474 case Sub:
475 newval = PyNumber_Subtract(lv, rv);
476 break;
477 case Mult:
478 newval = safe_multiply(lv, rv);
479 break;
480 case Div:
481 newval = PyNumber_TrueDivide(lv, rv);
482 break;
483 case FloorDiv:
484 newval = PyNumber_FloorDivide(lv, rv);
485 break;
486 case Mod:
487 newval = safe_mod(lv, rv);
488 break;
489 case Pow:
490 newval = safe_power(lv, rv);
491 break;
492 case LShift:
493 newval = safe_lshift(lv, rv);
494 break;
495 case RShift:
496 newval = PyNumber_Rshift(lv, rv);
497 break;
498 case BitOr:
499 newval = PyNumber_Or(lv, rv);
500 break;
501 case BitXor:
502 newval = PyNumber_Xor(lv, rv);
503 break;
504 case BitAnd:
505 newval = PyNumber_And(lv, rv);
506 break;
507 // No builtin constants implement the following operators
508 case MatMult:
509 return 1;
510 // No default case, so the compiler will emit a warning if new binary
511 // operators are added without being handled here
512 }
513
514 return make_const(node, newval, arena);
515 }
516
517 static PyObject*
make_const_tuple(asdl_expr_seq * elts)518 make_const_tuple(asdl_expr_seq *elts)
519 {
520 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
521 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
522 if (e->kind != Constant_kind) {
523 return NULL;
524 }
525 }
526
527 PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
528 if (newval == NULL) {
529 return NULL;
530 }
531
532 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
533 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
534 PyObject *v = e->v.Constant.value;
535 Py_INCREF(v);
536 PyTuple_SET_ITEM(newval, i, v);
537 }
538 return newval;
539 }
540
541 static int
fold_tuple(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)542 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
543 {
544 PyObject *newval;
545
546 if (node->v.Tuple.ctx != Load)
547 return 1;
548
549 newval = make_const_tuple(node->v.Tuple.elts);
550 return make_const(node, newval, arena);
551 }
552
553 static int
fold_subscr(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)554 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
555 {
556 PyObject *newval;
557 expr_ty arg, idx;
558
559 arg = node->v.Subscript.value;
560 idx = node->v.Subscript.slice;
561 if (node->v.Subscript.ctx != Load ||
562 arg->kind != Constant_kind ||
563 idx->kind != Constant_kind)
564 {
565 return 1;
566 }
567
568 newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
569 return make_const(node, newval, arena);
570 }
571
572 /* Change literal list or set of constants into constant
573 tuple or frozenset respectively. Change literal list of
574 non-constants into tuple.
575 Used for right operand of "in" and "not in" tests and for iterable
576 in "for" loop and comprehensions.
577 */
578 static int
fold_iter(expr_ty arg,PyArena * arena,_PyASTOptimizeState * state)579 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
580 {
581 PyObject *newval;
582 if (arg->kind == List_kind) {
583 /* First change a list into tuple. */
584 asdl_expr_seq *elts = arg->v.List.elts;
585 if (has_starred(elts)) {
586 return 1;
587 }
588 expr_context_ty ctx = arg->v.List.ctx;
589 arg->kind = Tuple_kind;
590 arg->v.Tuple.elts = elts;
591 arg->v.Tuple.ctx = ctx;
592 /* Try to create a constant tuple. */
593 newval = make_const_tuple(elts);
594 }
595 else if (arg->kind == Set_kind) {
596 newval = make_const_tuple(arg->v.Set.elts);
597 if (newval) {
598 Py_SETREF(newval, PyFrozenSet_New(newval));
599 }
600 }
601 else {
602 return 1;
603 }
604 return make_const(arg, newval, arena);
605 }
606
607 static int
fold_compare(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)608 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
609 {
610 asdl_int_seq *ops;
611 asdl_expr_seq *args;
612 Py_ssize_t i;
613
614 ops = node->v.Compare.ops;
615 args = node->v.Compare.comparators;
616 /* Change literal list or set in 'in' or 'not in' into
617 tuple or frozenset respectively. */
618 i = asdl_seq_LEN(ops) - 1;
619 int op = asdl_seq_GET(ops, i);
620 if (op == In || op == NotIn) {
621 if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
622 return 0;
623 }
624 }
625 return 1;
626 }
627
/* Forward declarations of the recursive AST walkers defined below.  Each one
   takes the node to process, the arena and the optimizer state, and returns
   1 on success or 0 on failure. */
static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
639
/* Traversal helpers for the astfold_* functions.
 *
 * All three expect the enclosing function to have `ctx_` (the arena) and
 * `state` in scope, and make the enclosing function return 0 as soon as the
 * callee reports failure.
 *
 *   CALL(FUNC, TYPE, ARG)      - call FUNC on ARG.
 *   CALL_OPT(FUNC, TYPE, ARG)  - call FUNC on ARG unless ARG is NULL.
 *   CALL_SEQ(FUNC, TYPE, ARG)  - call FUNC on every non-NULL element of the
 *                                asdl_<TYPE>_seq ARG.
 *
 * TYPE is only used by CALL_SEQ (to build the sequence and element type
 * names); CALL and CALL_OPT accept it for symmetry.
 *
 * Each macro is wrapped in do { ... } while (0) so that it behaves as a
 * single statement (safe inside an unbraced if/else) and must be followed by
 * a semicolon.  CALL_SEQ indexes with Py_ssize_t to match asdl_seq_LEN().
 */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
        for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
            if (elt != NULL && !FUNC(elt, ctx_, state)) { \
                return 0; \
            } \
        } \
    } while (0)
657
658
659 static int
astfold_body(asdl_stmt_seq * stmts,PyArena * ctx_,_PyASTOptimizeState * state)660 astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
661 {
662 int docstring = _PyAST_GetDocString(stmts) != NULL;
663 CALL_SEQ(astfold_stmt, stmt, stmts);
664 if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
665 stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
666 asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
667 if (!values) {
668 return 0;
669 }
670 asdl_seq_SET(values, 0, st->v.Expr.value);
671 expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
672 st->end_lineno, st->end_col_offset,
673 ctx_);
674 if (!expr) {
675 return 0;
676 }
677 st->v.Expr.value = expr;
678 }
679 return 1;
680 }
681
682 static int
astfold_mod(mod_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)683 astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
684 {
685 switch (node_->kind) {
686 case Module_kind:
687 CALL(astfold_body, asdl_seq, node_->v.Module.body);
688 break;
689 case Interactive_kind:
690 CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
691 break;
692 case Expression_kind:
693 CALL(astfold_expr, expr_ty, node_->v.Expression.body);
694 break;
695 // The following top level nodes don't participate in constant folding
696 case FunctionType_kind:
697 break;
698 // No default case, so the compiler will emit a warning if new top level
699 // compilation nodes are added without being handled here
700 }
701 return 1;
702 }
703
704 static int
astfold_expr(expr_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)705 astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
706 {
707 if (++state->recursion_depth > state->recursion_limit) {
708 PyErr_SetString(PyExc_RecursionError,
709 "maximum recursion depth exceeded during compilation");
710 return 0;
711 }
712 switch (node_->kind) {
713 case BoolOp_kind:
714 CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
715 break;
716 case BinOp_kind:
717 CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
718 CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
719 CALL(fold_binop, expr_ty, node_);
720 break;
721 case UnaryOp_kind:
722 CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
723 CALL(fold_unaryop, expr_ty, node_);
724 break;
725 case Lambda_kind:
726 CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
727 CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
728 break;
729 case IfExp_kind:
730 CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
731 CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
732 CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
733 break;
734 case Dict_kind:
735 CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
736 CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
737 break;
738 case Set_kind:
739 CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
740 break;
741 case ListComp_kind:
742 CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
743 CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
744 break;
745 case SetComp_kind:
746 CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
747 CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
748 break;
749 case DictComp_kind:
750 CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
751 CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
752 CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
753 break;
754 case GeneratorExp_kind:
755 CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
756 CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
757 break;
758 case Await_kind:
759 CALL(astfold_expr, expr_ty, node_->v.Await.value);
760 break;
761 case Yield_kind:
762 CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
763 break;
764 case YieldFrom_kind:
765 CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
766 break;
767 case Compare_kind:
768 CALL(astfold_expr, expr_ty, node_->v.Compare.left);
769 CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
770 CALL(fold_compare, expr_ty, node_);
771 break;
772 case Call_kind:
773 CALL(astfold_expr, expr_ty, node_->v.Call.func);
774 CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
775 CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
776 break;
777 case FormattedValue_kind:
778 CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
779 CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
780 break;
781 case JoinedStr_kind:
782 CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
783 break;
784 case Attribute_kind:
785 CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
786 break;
787 case Subscript_kind:
788 CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
789 CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
790 CALL(fold_subscr, expr_ty, node_);
791 break;
792 case Starred_kind:
793 CALL(astfold_expr, expr_ty, node_->v.Starred.value);
794 break;
795 case Slice_kind:
796 CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
797 CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
798 CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
799 break;
800 case List_kind:
801 CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
802 break;
803 case Tuple_kind:
804 CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
805 CALL(fold_tuple, expr_ty, node_);
806 break;
807 case Name_kind:
808 if (node_->v.Name.ctx == Load &&
809 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
810 state->recursion_depth--;
811 return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
812 }
813 break;
814 case NamedExpr_kind:
815 CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
816 break;
817 case Constant_kind:
818 // Already a constant, nothing further to do
819 break;
820 // No default case, so the compiler will emit a warning if new expression
821 // kinds are added without being handled here
822 }
823 state->recursion_depth--;
824 return 1;
825 }
826
827 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)828 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
829 {
830 CALL(astfold_expr, expr_ty, node_->value);
831 return 1;
832 }
833
834 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)835 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
836 {
837 CALL(astfold_expr, expr_ty, node_->target);
838 CALL(astfold_expr, expr_ty, node_->iter);
839 CALL_SEQ(astfold_expr, expr, node_->ifs);
840
841 CALL(fold_iter, expr_ty, node_->iter);
842 return 1;
843 }
844
845 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)846 astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
847 {
848 CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
849 CALL_SEQ(astfold_arg, arg, node_->args);
850 CALL_OPT(astfold_arg, arg_ty, node_->vararg);
851 CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
852 CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
853 CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
854 CALL_SEQ(astfold_expr, expr, node_->defaults);
855 return 1;
856 }
857
858 static int
astfold_arg(arg_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)859 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
860 {
861 if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
862 CALL_OPT(astfold_expr, expr_ty, node_->annotation);
863 }
864 return 1;
865 }
866
867 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)868 astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
869 {
870 if (++state->recursion_depth > state->recursion_limit) {
871 PyErr_SetString(PyExc_RecursionError,
872 "maximum recursion depth exceeded during compilation");
873 return 0;
874 }
875 switch (node_->kind) {
876 case FunctionDef_kind:
877 CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
878 CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
879 CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
880 if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
881 CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
882 }
883 break;
884 case AsyncFunctionDef_kind:
885 CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
886 CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
887 CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
888 if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
889 CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
890 }
891 break;
892 case ClassDef_kind:
893 CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
894 CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
895 CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
896 CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
897 break;
898 case Return_kind:
899 CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
900 break;
901 case Delete_kind:
902 CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
903 break;
904 case Assign_kind:
905 CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
906 CALL(astfold_expr, expr_ty, node_->v.Assign.value);
907 break;
908 case AugAssign_kind:
909 CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
910 CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
911 break;
912 case AnnAssign_kind:
913 CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
914 if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
915 CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
916 }
917 CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
918 break;
919 case For_kind:
920 CALL(astfold_expr, expr_ty, node_->v.For.target);
921 CALL(astfold_expr, expr_ty, node_->v.For.iter);
922 CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
923 CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);
924
925 CALL(fold_iter, expr_ty, node_->v.For.iter);
926 break;
927 case AsyncFor_kind:
928 CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
929 CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
930 CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
931 CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
932 break;
933 case While_kind:
934 CALL(astfold_expr, expr_ty, node_->v.While.test);
935 CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
936 CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
937 break;
938 case If_kind:
939 CALL(astfold_expr, expr_ty, node_->v.If.test);
940 CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
941 CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
942 break;
943 case With_kind:
944 CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
945 CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
946 break;
947 case AsyncWith_kind:
948 CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
949 CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
950 break;
951 case Raise_kind:
952 CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
953 CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
954 break;
955 case Try_kind:
956 CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
957 CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
958 CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
959 CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
960 break;
961 case TryStar_kind:
962 CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
963 CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
964 CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
965 CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
966 break;
967 case Assert_kind:
968 CALL(astfold_expr, expr_ty, node_->v.Assert.test);
969 CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
970 break;
971 case Expr_kind:
972 CALL(astfold_expr, expr_ty, node_->v.Expr.value);
973 break;
974 case Match_kind:
975 CALL(astfold_expr, expr_ty, node_->v.Match.subject);
976 CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
977 break;
978 // The following statements don't contain any subexpressions to be folded
979 case Import_kind:
980 case ImportFrom_kind:
981 case Global_kind:
982 case Nonlocal_kind:
983 case Pass_kind:
984 case Break_kind:
985 case Continue_kind:
986 break;
987 // No default case, so the compiler will emit a warning if new statement
988 // kinds are added without being handled here
989 }
990 state->recursion_depth--;
991 return 1;
992 }
993
994 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)995 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
996 {
997 switch (node_->kind) {
998 case ExceptHandler_kind:
999 CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
1000 CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
1001 break;
1002 // No default case, so the compiler will emit a warning if new handler
1003 // kinds are added without being handled here
1004 }
1005 return 1;
1006 }
1007
1008 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1009 astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1010 {
1011 CALL(astfold_expr, expr_ty, node_->context_expr);
1012 CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
1013 return 1;
1014 }
1015
1016 static int
astfold_pattern(pattern_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1017 astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1018 {
1019 // Currently, this is really only used to form complex/negative numeric
1020 // constants in MatchValue and MatchMapping nodes
1021 // We still recurse into all subexpressions and subpatterns anyway
1022 if (++state->recursion_depth > state->recursion_limit) {
1023 PyErr_SetString(PyExc_RecursionError,
1024 "maximum recursion depth exceeded during compilation");
1025 return 0;
1026 }
1027 switch (node_->kind) {
1028 case MatchValue_kind:
1029 CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
1030 break;
1031 case MatchSingleton_kind:
1032 break;
1033 case MatchSequence_kind:
1034 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
1035 break;
1036 case MatchMapping_kind:
1037 CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
1038 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
1039 break;
1040 case MatchClass_kind:
1041 CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
1042 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
1043 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
1044 break;
1045 case MatchStar_kind:
1046 break;
1047 case MatchAs_kind:
1048 if (node_->v.MatchAs.pattern) {
1049 CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
1050 }
1051 break;
1052 case MatchOr_kind:
1053 CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
1054 break;
1055 // No default case, so the compiler will emit a warning if new pattern
1056 // kinds are added without being handled here
1057 }
1058 state->recursion_depth--;
1059 return 1;
1060 }
1061
1062 static int
astfold_match_case(match_case_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)1063 astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1064 {
1065 CALL(astfold_pattern, expr_ty, node_->pattern);
1066 CALL_OPT(astfold_expr, expr_ty, node_->guard);
1067 CALL_SEQ(astfold_stmt, stmt, node_->body);
1068 return 1;
1069 }
1070
1071 #undef CALL
1072 #undef CALL_OPT
1073 #undef CALL_SEQ
1074
/* See comments in symtable.c.
   Factor by which _PyAST_Optimize() multiplies the thread's current
   recursion depth and the interpreter's recursion limit to obtain the
   counters used by the astfold_* walkers. */
#define COMPILER_STACK_FRAME_SCALE 3
1077
1078 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,_PyASTOptimizeState * state)1079 _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
1080 {
1081 PyThreadState *tstate;
1082 int recursion_limit = Py_GetRecursionLimit();
1083 int starting_recursion_depth;
1084
1085 /* Setup recursion depth check counters */
1086 tstate = _PyThreadState_GET();
1087 if (!tstate) {
1088 return 0;
1089 }
1090 /* Be careful here to prevent overflow. */
1091 int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
1092 starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1093 recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
1094 state->recursion_depth = starting_recursion_depth;
1095 state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1096 recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
1097
1098 int ret = astfold_mod(mod, arena, state);
1099 assert(ret || PyErr_Occurred());
1100
1101 /* Check that the recursion depth counting balanced correctly */
1102 if (ret && state->recursion_depth != starting_recursion_depth) {
1103 PyErr_Format(PyExc_SystemError,
1104 "AST optimizer recursion depth mismatch (before=%d, after=%d)",
1105 starting_recursion_depth, state->recursion_depth);
1106 return 0;
1107 }
1108
1109 return ret;
1110 }
1111