/*
 * This file exposes the PyAST_Validate interface, which checks the integrity
 * of a given abstract syntax tree (potentially constructed manually).
 */
5 #include "Python.h"
6 #include "pycore_ast.h"           // asdl_stmt_seq
7 #include "pycore_pystate.h"       // _PyThreadState_GET()
8 
9 #include <assert.h>
10 #include <stdbool.h>
11 
/* Bookkeeping shared by the validate_* functions to bound nesting depth. */
struct validator {
    /* How many nested validation calls are currently in progress. */
    int recursion_depth;
    /* Depth above which validation fails with RecursionError. */
    int recursion_limit;
};
16 
17 static int validate_stmts(struct validator *, asdl_stmt_seq *);
18 static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
19 static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
20 static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
21 static int validate_stmt(struct validator *, stmt_ty);
22 static int validate_expr(struct validator *, expr_ty, expr_context_ty);
23 static int validate_pattern(struct validator *, pattern_ty, int);
24 
/*
 * Sanity-check the source position fields (lineno, col_offset, end_lineno,
 * end_col_offset) of an AST node.
 *
 * The node is rejected when:
 *   - its start line is after its end line;
 *   - a negative start line or start column is not mirrored by an identical
 *     end value;
 *   - it starts and ends on the same line but the start column is after the
 *     end column.
 *
 * On rejection a ValueError is set and the *enclosing function* returns 0,
 * so this may only be used inside functions returning int.
 *
 * `node` is evaluated several times: pass an expression without side effects.
 * The body is wrapped in do { ... } while (0) so that the macro followed by a
 * semicolon is exactly one statement (safe under an un-braced if/else), and
 * every use of `node` is parenthesized so arguments such as `&x` work.
 */
#define VALIDATE_POSITIONS(node) \
    do { \
        if ((node)->lineno > (node)->end_lineno) { \
            PyErr_Format(PyExc_ValueError, \
                         "AST node line range (%d, %d) is not valid", \
                         (node)->lineno, (node)->end_lineno); \
            return 0; \
        } \
        if (((node)->lineno < 0 && (node)->end_lineno != (node)->lineno) || \
            ((node)->col_offset < 0 && (node)->col_offset != (node)->end_col_offset)) { \
            PyErr_Format(PyExc_ValueError, \
                         "AST node column range (%d, %d) for line range (%d, %d) is not valid", \
                         (node)->col_offset, (node)->end_col_offset, \
                         (node)->lineno, (node)->end_lineno); \
            return 0; \
        } \
        if ((node)->lineno == (node)->end_lineno && \
            (node)->col_offset > (node)->end_col_offset) { \
            PyErr_Format(PyExc_ValueError, \
                         "line %d, column %d-%d is not a valid range", \
                         (node)->lineno, (node)->col_offset, (node)->end_col_offset); \
            return 0; \
        } \
    } while (0)
45 
46 static int
validate_name(PyObject * name)47 validate_name(PyObject *name)
48 {
49     assert(!PyErr_Occurred());
50     assert(PyUnicode_Check(name));
51     static const char * const forbidden[] = {
52         "None",
53         "True",
54         "False",
55         NULL
56     };
57     for (int i = 0; forbidden[i] != NULL; i++) {
58         if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
59             PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
60             return 0;
61         }
62     }
63     return 1;
64 }
65 
66 static int
validate_comprehension(struct validator * state,asdl_comprehension_seq * gens)67 validate_comprehension(struct validator *state, asdl_comprehension_seq *gens)
68 {
69     assert(!PyErr_Occurred());
70     if (!asdl_seq_LEN(gens)) {
71         PyErr_SetString(PyExc_ValueError, "comprehension with no generators");
72         return 0;
73     }
74     for (Py_ssize_t i = 0; i < asdl_seq_LEN(gens); i++) {
75         comprehension_ty comp = asdl_seq_GET(gens, i);
76         if (!validate_expr(state, comp->target, Store) ||
77             !validate_expr(state, comp->iter, Load) ||
78             !validate_exprs(state, comp->ifs, Load, 0))
79             return 0;
80     }
81     return 1;
82 }
83 
84 static int
validate_keywords(struct validator * state,asdl_keyword_seq * keywords)85 validate_keywords(struct validator *state, asdl_keyword_seq *keywords)
86 {
87     assert(!PyErr_Occurred());
88     for (Py_ssize_t i = 0; i < asdl_seq_LEN(keywords); i++)
89         if (!validate_expr(state, (asdl_seq_GET(keywords, i))->value, Load))
90             return 0;
91     return 1;
92 }
93 
94 static int
validate_args(struct validator * state,asdl_arg_seq * args)95 validate_args(struct validator *state, asdl_arg_seq *args)
96 {
97     assert(!PyErr_Occurred());
98     for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) {
99         arg_ty arg = asdl_seq_GET(args, i);
100         VALIDATE_POSITIONS(arg);
101         if (arg->annotation && !validate_expr(state, arg->annotation, Load))
102             return 0;
103     }
104     return 1;
105 }
106 
107 static const char *
expr_context_name(expr_context_ty ctx)108 expr_context_name(expr_context_ty ctx)
109 {
110     switch (ctx) {
111     case Load:
112         return "Load";
113     case Store:
114         return "Store";
115     case Del:
116         return "Del";
117     // No default case so compiler emits warning for unhandled cases
118     }
119     Py_UNREACHABLE();
120 }
121 
122 static int
validate_arguments(struct validator * state,arguments_ty args)123 validate_arguments(struct validator *state, arguments_ty args)
124 {
125     assert(!PyErr_Occurred());
126     if (!validate_args(state, args->posonlyargs) || !validate_args(state, args->args)) {
127         return 0;
128     }
129     if (args->vararg && args->vararg->annotation
130         && !validate_expr(state, args->vararg->annotation, Load)) {
131             return 0;
132     }
133     if (!validate_args(state, args->kwonlyargs))
134         return 0;
135     if (args->kwarg && args->kwarg->annotation
136         && !validate_expr(state, args->kwarg->annotation, Load)) {
137             return 0;
138     }
139     if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
140         PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
141         return 0;
142     }
143     if (asdl_seq_LEN(args->kw_defaults) != asdl_seq_LEN(args->kwonlyargs)) {
144         PyErr_SetString(PyExc_ValueError, "length of kwonlyargs is not the same as "
145                         "kw_defaults on arguments");
146         return 0;
147     }
148     return validate_exprs(state, args->defaults, Load, 0) && validate_exprs(state, args->kw_defaults, Load, 1);
149 }
150 
151 static int
validate_constant(struct validator * state,PyObject * value)152 validate_constant(struct validator *state, PyObject *value)
153 {
154     assert(!PyErr_Occurred());
155     if (value == Py_None || value == Py_Ellipsis)
156         return 1;
157 
158     if (PyLong_CheckExact(value)
159             || PyFloat_CheckExact(value)
160             || PyComplex_CheckExact(value)
161             || PyBool_Check(value)
162             || PyUnicode_CheckExact(value)
163             || PyBytes_CheckExact(value))
164         return 1;
165 
166     if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
167         if (++state->recursion_depth > state->recursion_limit) {
168             PyErr_SetString(PyExc_RecursionError,
169                             "maximum recursion depth exceeded during compilation");
170             return 0;
171         }
172 
173         PyObject *it = PyObject_GetIter(value);
174         if (it == NULL)
175             return 0;
176 
177         while (1) {
178             PyObject *item = PyIter_Next(it);
179             if (item == NULL) {
180                 if (PyErr_Occurred()) {
181                     Py_DECREF(it);
182                     return 0;
183                 }
184                 break;
185             }
186 
187             if (!validate_constant(state, item)) {
188                 Py_DECREF(it);
189                 Py_DECREF(item);
190                 return 0;
191             }
192             Py_DECREF(item);
193         }
194 
195         Py_DECREF(it);
196         --state->recursion_depth;
197         return 1;
198     }
199 
200     if (!PyErr_Occurred()) {
201         PyErr_Format(PyExc_TypeError,
202                      "got an invalid type in Constant: %s",
203                      _PyType_Name(Py_TYPE(value)));
204     }
205     return 0;
206 }
207 
208 static int
validate_expr(struct validator * state,expr_ty exp,expr_context_ty ctx)209 validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
210 {
211     assert(!PyErr_Occurred());
212     VALIDATE_POSITIONS(exp);
213     int ret = -1;
214     if (++state->recursion_depth > state->recursion_limit) {
215         PyErr_SetString(PyExc_RecursionError,
216                         "maximum recursion depth exceeded during compilation");
217         return 0;
218     }
219     int check_ctx = 1;
220     expr_context_ty actual_ctx;
221 
222     /* First check expression context. */
223     switch (exp->kind) {
224     case Attribute_kind:
225         actual_ctx = exp->v.Attribute.ctx;
226         break;
227     case Subscript_kind:
228         actual_ctx = exp->v.Subscript.ctx;
229         break;
230     case Starred_kind:
231         actual_ctx = exp->v.Starred.ctx;
232         break;
233     case Name_kind:
234         if (!validate_name(exp->v.Name.id)) {
235             return 0;
236         }
237         actual_ctx = exp->v.Name.ctx;
238         break;
239     case List_kind:
240         actual_ctx = exp->v.List.ctx;
241         break;
242     case Tuple_kind:
243         actual_ctx = exp->v.Tuple.ctx;
244         break;
245     default:
246         if (ctx != Load) {
247             PyErr_Format(PyExc_ValueError, "expression which can't be "
248                          "assigned to in %s context", expr_context_name(ctx));
249             return 0;
250         }
251         check_ctx = 0;
252         /* set actual_ctx to prevent gcc warning */
253         actual_ctx = 0;
254     }
255     if (check_ctx && actual_ctx != ctx) {
256         PyErr_Format(PyExc_ValueError, "expression must have %s context but has %s instead",
257                      expr_context_name(ctx), expr_context_name(actual_ctx));
258         return 0;
259     }
260 
261     /* Now validate expression. */
262     switch (exp->kind) {
263     case BoolOp_kind:
264         if (asdl_seq_LEN(exp->v.BoolOp.values) < 2) {
265             PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values");
266             return 0;
267         }
268         ret = validate_exprs(state, exp->v.BoolOp.values, Load, 0);
269         break;
270     case BinOp_kind:
271         ret = validate_expr(state, exp->v.BinOp.left, Load) &&
272             validate_expr(state, exp->v.BinOp.right, Load);
273         break;
274     case UnaryOp_kind:
275         ret = validate_expr(state, exp->v.UnaryOp.operand, Load);
276         break;
277     case Lambda_kind:
278         ret = validate_arguments(state, exp->v.Lambda.args) &&
279             validate_expr(state, exp->v.Lambda.body, Load);
280         break;
281     case IfExp_kind:
282         ret = validate_expr(state, exp->v.IfExp.test, Load) &&
283             validate_expr(state, exp->v.IfExp.body, Load) &&
284             validate_expr(state, exp->v.IfExp.orelse, Load);
285         break;
286     case Dict_kind:
287         if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) {
288             PyErr_SetString(PyExc_ValueError,
289                             "Dict doesn't have the same number of keys as values");
290             return 0;
291         }
292         /* null_ok=1 for keys expressions to allow dict unpacking to work in
293            dict literals, i.e. ``{**{a:b}}`` */
294         ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) &&
295             validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0);
296         break;
297     case Set_kind:
298         ret = validate_exprs(state, exp->v.Set.elts, Load, 0);
299         break;
300 #define COMP(NAME) \
301         case NAME ## _kind: \
302             ret = validate_comprehension(state, exp->v.NAME.generators) && \
303                 validate_expr(state, exp->v.NAME.elt, Load); \
304             break;
305     COMP(ListComp)
306     COMP(SetComp)
307     COMP(GeneratorExp)
308 #undef COMP
309     case DictComp_kind:
310         ret = validate_comprehension(state, exp->v.DictComp.generators) &&
311             validate_expr(state, exp->v.DictComp.key, Load) &&
312             validate_expr(state, exp->v.DictComp.value, Load);
313         break;
314     case Yield_kind:
315         ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load);
316         break;
317     case YieldFrom_kind:
318         ret = validate_expr(state, exp->v.YieldFrom.value, Load);
319         break;
320     case Await_kind:
321         ret = validate_expr(state, exp->v.Await.value, Load);
322         break;
323     case Compare_kind:
324         if (!asdl_seq_LEN(exp->v.Compare.comparators)) {
325             PyErr_SetString(PyExc_ValueError, "Compare with no comparators");
326             return 0;
327         }
328         if (asdl_seq_LEN(exp->v.Compare.comparators) !=
329             asdl_seq_LEN(exp->v.Compare.ops)) {
330             PyErr_SetString(PyExc_ValueError, "Compare has a different number "
331                             "of comparators and operands");
332             return 0;
333         }
334         ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) &&
335             validate_expr(state, exp->v.Compare.left, Load);
336         break;
337     case Call_kind:
338         ret = validate_expr(state, exp->v.Call.func, Load) &&
339             validate_exprs(state, exp->v.Call.args, Load, 0) &&
340             validate_keywords(state, exp->v.Call.keywords);
341         break;
342     case Constant_kind:
343         if (!validate_constant(state, exp->v.Constant.value)) {
344             return 0;
345         }
346         ret = 1;
347         break;
348     case JoinedStr_kind:
349         ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0);
350         break;
351     case FormattedValue_kind:
352         if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0)
353             return 0;
354         if (exp->v.FormattedValue.format_spec) {
355             ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load);
356             break;
357         }
358         ret = 1;
359         break;
360     case Attribute_kind:
361         ret = validate_expr(state, exp->v.Attribute.value, Load);
362         break;
363     case Subscript_kind:
364         ret = validate_expr(state, exp->v.Subscript.slice, Load) &&
365             validate_expr(state, exp->v.Subscript.value, Load);
366         break;
367     case Starred_kind:
368         ret = validate_expr(state, exp->v.Starred.value, ctx);
369         break;
370     case Slice_kind:
371         ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) &&
372             (!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) &&
373             (!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load));
374         break;
375     case List_kind:
376         ret = validate_exprs(state, exp->v.List.elts, ctx, 0);
377         break;
378     case Tuple_kind:
379         ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0);
380         break;
381     case NamedExpr_kind:
382         ret = validate_expr(state, exp->v.NamedExpr.value, Load);
383         break;
384     /* This last case doesn't have any checking. */
385     case Name_kind:
386         ret = 1;
387         break;
388     // No default case so compiler emits warning for unhandled cases
389     }
390     if (ret < 0) {
391         PyErr_SetString(PyExc_SystemError, "unexpected expression");
392         ret = 0;
393     }
394     state->recursion_depth--;
395     return ret;
396 }
397 
398 
399 // Note: the ensure_literal_* functions are only used to validate a restricted
400 //       set of non-recursive literals that have already been checked with
401 //       validate_expr, so they don't accept the validator state
402 static int
ensure_literal_number(expr_ty exp,bool allow_real,bool allow_imaginary)403 ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary)
404 {
405     assert(exp->kind == Constant_kind);
406     PyObject *value = exp->v.Constant.value;
407     return (allow_real && PyFloat_CheckExact(value)) ||
408            (allow_real && PyLong_CheckExact(value)) ||
409            (allow_imaginary && PyComplex_CheckExact(value));
410 }
411 
412 static int
ensure_literal_negative(expr_ty exp,bool allow_real,bool allow_imaginary)413 ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary)
414 {
415     assert(exp->kind == UnaryOp_kind);
416     // Must be negation ...
417     if (exp->v.UnaryOp.op != USub) {
418         return 0;
419     }
420     // ... of a constant ...
421     expr_ty operand = exp->v.UnaryOp.operand;
422     if (operand->kind != Constant_kind) {
423         return 0;
424     }
425     // ... number
426     return ensure_literal_number(operand, allow_real, allow_imaginary);
427 }
428 
429 static int
ensure_literal_complex(expr_ty exp)430 ensure_literal_complex(expr_ty exp)
431 {
432     assert(exp->kind == BinOp_kind);
433     expr_ty left = exp->v.BinOp.left;
434     expr_ty right = exp->v.BinOp.right;
435     // Ensure op is addition or subtraction
436     if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) {
437         return 0;
438     }
439     // Check LHS is a real number (potentially signed)
440     switch (left->kind)
441     {
442         case Constant_kind:
443             if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) {
444                 return 0;
445             }
446             break;
447         case UnaryOp_kind:
448             if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) {
449                 return 0;
450             }
451             break;
452         default:
453             return 0;
454     }
455     // Check RHS is an imaginary number (no separate sign allowed)
456     switch (right->kind)
457     {
458         case Constant_kind:
459             if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) {
460                 return 0;
461             }
462             break;
463         default:
464             return 0;
465     }
466     return 1;
467 }
468 
469 static int
validate_pattern_match_value(struct validator * state,expr_ty exp)470 validate_pattern_match_value(struct validator *state, expr_ty exp)
471 {
472     assert(!PyErr_Occurred());
473     if (!validate_expr(state, exp, Load)) {
474         return 0;
475     }
476 
477     switch (exp->kind)
478     {
479         case Constant_kind:
480             /* Ellipsis and immutable sequences are not allowed.
481                For True, False and None, MatchSingleton() should
482                be used */
483             if (!validate_expr(state, exp, Load)) {
484                 return 0;
485             }
486             PyObject *literal = exp->v.Constant.value;
487             if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
488                 PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
489                 PyUnicode_CheckExact(literal)) {
490                 return 1;
491             }
492             PyErr_SetString(PyExc_ValueError,
493                             "unexpected constant inside of a literal pattern");
494             return 0;
495         case Attribute_kind:
496             // Constants and attribute lookups are always permitted
497             return 1;
498         case UnaryOp_kind:
499             // Negated numbers are permitted (whether real or imaginary)
500             // Compiler will complain if AST folding doesn't create a constant
501             if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) {
502                 return 1;
503             }
504             break;
505         case BinOp_kind:
506             // Complex literals are permitted
507             // Compiler will complain if AST folding doesn't create a constant
508             if (ensure_literal_complex(exp)) {
509                 return 1;
510             }
511             break;
512         case JoinedStr_kind:
513             // Handled in the later stages
514             return 1;
515         default:
516             break;
517     }
518     PyErr_SetString(PyExc_ValueError,
519                     "patterns may only match literals and attribute lookups");
520     return 0;
521 }
522 
523 static int
validate_capture(PyObject * name)524 validate_capture(PyObject *name)
525 {
526     assert(!PyErr_Occurred());
527     if (_PyUnicode_EqualToASCIIString(name, "_")) {
528         PyErr_Format(PyExc_ValueError, "can't capture name '_' in patterns");
529         return 0;
530     }
531     return validate_name(name);
532 }
533 
534 static int
validate_pattern(struct validator * state,pattern_ty p,int star_ok)535 validate_pattern(struct validator *state, pattern_ty p, int star_ok)
536 {
537     assert(!PyErr_Occurred());
538     VALIDATE_POSITIONS(p);
539     int ret = -1;
540     if (++state->recursion_depth > state->recursion_limit) {
541         PyErr_SetString(PyExc_RecursionError,
542                         "maximum recursion depth exceeded during compilation");
543         return 0;
544     }
545     switch (p->kind) {
546         case MatchValue_kind:
547             ret = validate_pattern_match_value(state, p->v.MatchValue.value);
548             break;
549         case MatchSingleton_kind:
550             ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
551             if (!ret) {
552                 PyErr_SetString(PyExc_ValueError,
553                                 "MatchSingleton can only contain True, False and None");
554             }
555             break;
556         case MatchSequence_kind:
557             ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
558             break;
559         case MatchMapping_kind:
560             if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
561                 PyErr_SetString(PyExc_ValueError,
562                                 "MatchMapping doesn't have the same number of keys as patterns");
563                 ret = 0;
564                 break;
565             }
566 
567             if (p->v.MatchMapping.rest && !validate_capture(p->v.MatchMapping.rest)) {
568                 ret = 0;
569                 break;
570             }
571 
572             asdl_expr_seq *keys = p->v.MatchMapping.keys;
573             for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
574                 expr_ty key = asdl_seq_GET(keys, i);
575                 if (key->kind == Constant_kind) {
576                     PyObject *literal = key->v.Constant.value;
577                     if (literal == Py_None || PyBool_Check(literal)) {
578                         /* validate_pattern_match_value will ensure the key
579                            doesn't contain True, False and None but it is
580                            syntactically valid, so we will pass those on in
581                            a special case. */
582                         continue;
583                     }
584                 }
585                 if (!validate_pattern_match_value(state, key)) {
586                     ret = 0;
587                     break;
588                 }
589             }
590             if (ret == 0) {
591                 break;
592             }
593             ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
594             break;
595         case MatchClass_kind:
596             if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
597                 PyErr_SetString(PyExc_ValueError,
598                                 "MatchClass doesn't have the same number of keyword attributes as patterns");
599                 ret = 0;
600                 break;
601             }
602             if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
603                 ret = 0;
604                 break;
605             }
606 
607             expr_ty cls = p->v.MatchClass.cls;
608             while (1) {
609                 if (cls->kind == Name_kind) {
610                     break;
611                 }
612                 else if (cls->kind == Attribute_kind) {
613                     cls = cls->v.Attribute.value;
614                     continue;
615                 }
616                 else {
617                     PyErr_SetString(PyExc_ValueError,
618                                     "MatchClass cls field can only contain Name or Attribute nodes.");
619                     ret = 0;
620                     break;
621                 }
622             }
623             if (ret == 0) {
624                 break;
625             }
626 
627             for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
628                 PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
629                 if (!validate_name(identifier)) {
630                     ret = 0;
631                     break;
632                 }
633             }
634             if (ret == 0) {
635                 break;
636             }
637 
638             if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
639                 ret = 0;
640                 break;
641             }
642 
643             ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
644             break;
645         case MatchStar_kind:
646             if (!star_ok) {
647                 PyErr_SetString(PyExc_ValueError, "can't use MatchStar here");
648                 ret = 0;
649                 break;
650             }
651             ret = p->v.MatchStar.name == NULL || validate_capture(p->v.MatchStar.name);
652             break;
653         case MatchAs_kind:
654             if (p->v.MatchAs.name && !validate_capture(p->v.MatchAs.name)) {
655                 ret = 0;
656                 break;
657             }
658             if (p->v.MatchAs.pattern == NULL) {
659                 ret = 1;
660             }
661             else if (p->v.MatchAs.name == NULL) {
662                 PyErr_SetString(PyExc_ValueError,
663                                 "MatchAs must specify a target name if a pattern is given");
664                 ret = 0;
665             }
666             else {
667                 ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0);
668             }
669             break;
670         case MatchOr_kind:
671             if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
672                 PyErr_SetString(PyExc_ValueError,
673                                 "MatchOr requires at least 2 patterns");
674                 ret = 0;
675                 break;
676             }
677             ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
678             break;
679     // No default case, so the compiler will emit a warning if new pattern
680     // kinds are added without being handled here
681     }
682     if (ret < 0) {
683         PyErr_SetString(PyExc_SystemError, "unexpected pattern");
684         ret = 0;
685     }
686     state->recursion_depth--;
687     return ret;
688 }
689 
690 static int
_validate_nonempty_seq(asdl_seq * seq,const char * what,const char * owner)691 _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
692 {
693     if (asdl_seq_LEN(seq))
694         return 1;
695     PyErr_Format(PyExc_ValueError, "empty %s on %s", what, owner);
696     return 0;
697 }
698 #define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner)
699 
700 static int
validate_assignlist(struct validator * state,asdl_expr_seq * targets,expr_context_ty ctx)701 validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx)
702 {
703     assert(!PyErr_Occurred());
704     return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") &&
705         validate_exprs(state, targets, ctx, 0);
706 }
707 
708 static int
validate_body(struct validator * state,asdl_stmt_seq * body,const char * owner)709 validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner)
710 {
711     assert(!PyErr_Occurred());
712     return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body);
713 }
714 
715 static int
validate_stmt(struct validator * state,stmt_ty stmt)716 validate_stmt(struct validator *state, stmt_ty stmt)
717 {
718     assert(!PyErr_Occurred());
719     VALIDATE_POSITIONS(stmt);
720     int ret = -1;
721     if (++state->recursion_depth > state->recursion_limit) {
722         PyErr_SetString(PyExc_RecursionError,
723                         "maximum recursion depth exceeded during compilation");
724         return 0;
725     }
726     switch (stmt->kind) {
727     case FunctionDef_kind:
728         ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
729             validate_arguments(state, stmt->v.FunctionDef.args) &&
730             validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) &&
731             (!stmt->v.FunctionDef.returns ||
732              validate_expr(state, stmt->v.FunctionDef.returns, Load));
733         break;
734     case ClassDef_kind:
735         ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") &&
736             validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) &&
737             validate_keywords(state, stmt->v.ClassDef.keywords) &&
738             validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0);
739         break;
740     case Return_kind:
741         ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load);
742         break;
743     case Delete_kind:
744         ret = validate_assignlist(state, stmt->v.Delete.targets, Del);
745         break;
746     case Assign_kind:
747         ret = validate_assignlist(state, stmt->v.Assign.targets, Store) &&
748             validate_expr(state, stmt->v.Assign.value, Load);
749         break;
750     case AugAssign_kind:
751         ret = validate_expr(state, stmt->v.AugAssign.target, Store) &&
752             validate_expr(state, stmt->v.AugAssign.value, Load);
753         break;
754     case AnnAssign_kind:
755         if (stmt->v.AnnAssign.target->kind != Name_kind &&
756             stmt->v.AnnAssign.simple) {
757             PyErr_SetString(PyExc_TypeError,
758                             "AnnAssign with simple non-Name target");
759             return 0;
760         }
761         ret = validate_expr(state, stmt->v.AnnAssign.target, Store) &&
762                (!stmt->v.AnnAssign.value ||
763                 validate_expr(state, stmt->v.AnnAssign.value, Load)) &&
764                validate_expr(state, stmt->v.AnnAssign.annotation, Load);
765         break;
766     case For_kind:
767         ret = validate_expr(state, stmt->v.For.target, Store) &&
768             validate_expr(state, stmt->v.For.iter, Load) &&
769             validate_body(state, stmt->v.For.body, "For") &&
770             validate_stmts(state, stmt->v.For.orelse);
771         break;
772     case AsyncFor_kind:
773         ret = validate_expr(state, stmt->v.AsyncFor.target, Store) &&
774             validate_expr(state, stmt->v.AsyncFor.iter, Load) &&
775             validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") &&
776             validate_stmts(state, stmt->v.AsyncFor.orelse);
777         break;
778     case While_kind:
779         ret = validate_expr(state, stmt->v.While.test, Load) &&
780             validate_body(state, stmt->v.While.body, "While") &&
781             validate_stmts(state, stmt->v.While.orelse);
782         break;
783     case If_kind:
784         ret = validate_expr(state, stmt->v.If.test, Load) &&
785             validate_body(state, stmt->v.If.body, "If") &&
786             validate_stmts(state, stmt->v.If.orelse);
787         break;
788     case With_kind:
789         if (!validate_nonempty_seq(stmt->v.With.items, "items", "With"))
790             return 0;
791         for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) {
792             withitem_ty item = asdl_seq_GET(stmt->v.With.items, i);
793             if (!validate_expr(state, item->context_expr, Load) ||
794                 (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
795                 return 0;
796         }
797         ret = validate_body(state, stmt->v.With.body, "With");
798         break;
799     case AsyncWith_kind:
800         if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith"))
801             return 0;
802         for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) {
803             withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i);
804             if (!validate_expr(state, item->context_expr, Load) ||
805                 (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
806                 return 0;
807         }
808         ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith");
809         break;
810     case Match_kind:
811         if (!validate_expr(state, stmt->v.Match.subject, Load)
812             || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) {
813             return 0;
814         }
815         for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
816             match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
817             if (!validate_pattern(state, m->pattern, /*star_ok=*/0)
818                 || (m->guard && !validate_expr(state, m->guard, Load))
819                 || !validate_body(state, m->body, "match_case")) {
820                 return 0;
821             }
822         }
823         ret = 1;
824         break;
825     case Raise_kind:
826         if (stmt->v.Raise.exc) {
827             ret = validate_expr(state, stmt->v.Raise.exc, Load) &&
828                 (!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load));
829             break;
830         }
831         if (stmt->v.Raise.cause) {
832             PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception");
833             return 0;
834         }
835         ret = 1;
836         break;
837     case Try_kind:
838         if (!validate_body(state, stmt->v.Try.body, "Try"))
839             return 0;
840         if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
841             !asdl_seq_LEN(stmt->v.Try.finalbody)) {
842             PyErr_SetString(PyExc_ValueError, "Try has neither except handlers nor finalbody");
843             return 0;
844         }
845         if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
846             asdl_seq_LEN(stmt->v.Try.orelse)) {
847             PyErr_SetString(PyExc_ValueError, "Try has orelse but no except handlers");
848             return 0;
849         }
850         for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) {
851             excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i);
852             VALIDATE_POSITIONS(handler);
853             if ((handler->v.ExceptHandler.type &&
854                  !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
855                 !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
856                 return 0;
857         }
858         ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) ||
859                 validate_stmts(state, stmt->v.Try.finalbody)) &&
860             (!asdl_seq_LEN(stmt->v.Try.orelse) ||
861              validate_stmts(state, stmt->v.Try.orelse));
862         break;
863     case TryStar_kind:
864         if (!validate_body(state, stmt->v.TryStar.body, "TryStar"))
865             return 0;
866         if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
867             !asdl_seq_LEN(stmt->v.TryStar.finalbody)) {
868             PyErr_SetString(PyExc_ValueError, "TryStar has neither except handlers nor finalbody");
869             return 0;
870         }
871         if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
872             asdl_seq_LEN(stmt->v.TryStar.orelse)) {
873             PyErr_SetString(PyExc_ValueError, "TryStar has orelse but no except handlers");
874             return 0;
875         }
876         for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.TryStar.handlers); i++) {
877             excepthandler_ty handler = asdl_seq_GET(stmt->v.TryStar.handlers, i);
878             if ((handler->v.ExceptHandler.type &&
879                  !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
880                 !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
881                 return 0;
882         }
883         ret = (!asdl_seq_LEN(stmt->v.TryStar.finalbody) ||
884                 validate_stmts(state, stmt->v.TryStar.finalbody)) &&
885             (!asdl_seq_LEN(stmt->v.TryStar.orelse) ||
886              validate_stmts(state, stmt->v.TryStar.orelse));
887         break;
888     case Assert_kind:
889         ret = validate_expr(state, stmt->v.Assert.test, Load) &&
890             (!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load));
891         break;
892     case Import_kind:
893         ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import");
894         break;
895     case ImportFrom_kind:
896         if (stmt->v.ImportFrom.level < 0) {
897             PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level");
898             return 0;
899         }
900         ret = validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom");
901         break;
902     case Global_kind:
903         ret = validate_nonempty_seq(stmt->v.Global.names, "names", "Global");
904         break;
905     case Nonlocal_kind:
906         ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal");
907         break;
908     case Expr_kind:
909         ret = validate_expr(state, stmt->v.Expr.value, Load);
910         break;
911     case AsyncFunctionDef_kind:
912         ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
913             validate_arguments(state, stmt->v.AsyncFunctionDef.args) &&
914             validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
915             (!stmt->v.AsyncFunctionDef.returns ||
916              validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load));
917         break;
918     case Pass_kind:
919     case Break_kind:
920     case Continue_kind:
921         ret = 1;
922         break;
923     // No default case so compiler emits warning for unhandled cases
924     }
925     if (ret < 0) {
926         PyErr_SetString(PyExc_SystemError, "unexpected statement");
927         ret = 0;
928     }
929     state->recursion_depth--;
930     return ret;
931 }
932 
933 static int
validate_stmts(struct validator * state,asdl_stmt_seq * seq)934 validate_stmts(struct validator *state, asdl_stmt_seq *seq)
935 {
936     assert(!PyErr_Occurred());
937     for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) {
938         stmt_ty stmt = asdl_seq_GET(seq, i);
939         if (stmt) {
940             if (!validate_stmt(state, stmt))
941                 return 0;
942         }
943         else {
944             PyErr_SetString(PyExc_ValueError,
945                             "None disallowed in statement list");
946             return 0;
947         }
948     }
949     return 1;
950 }
951 
952 static int
validate_exprs(struct validator * state,asdl_expr_seq * exprs,expr_context_ty ctx,int null_ok)953 validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok)
954 {
955     assert(!PyErr_Occurred());
956     for (Py_ssize_t i = 0; i < asdl_seq_LEN(exprs); i++) {
957         expr_ty expr = asdl_seq_GET(exprs, i);
958         if (expr) {
959             if (!validate_expr(state, expr, ctx))
960                 return 0;
961         }
962         else if (!null_ok) {
963             PyErr_SetString(PyExc_ValueError,
964                             "None disallowed in expression list");
965             return 0;
966         }
967 
968     }
969     return 1;
970 }
971 
972 static int
validate_patterns(struct validator * state,asdl_pattern_seq * patterns,int star_ok)973 validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
974 {
975     assert(!PyErr_Occurred());
976     for (Py_ssize_t i = 0; i < asdl_seq_LEN(patterns); i++) {
977         pattern_ty pattern = asdl_seq_GET(patterns, i);
978         if (!validate_pattern(state, pattern, star_ok)) {
979             return 0;
980         }
981     }
982     return 1;
983 }
984 
985 
986 /* See comments in symtable.c. */
987 #define COMPILER_STACK_FRAME_SCALE 3
988 
989 int
_PyAST_Validate(mod_ty mod)990 _PyAST_Validate(mod_ty mod)
991 {
992     assert(!PyErr_Occurred());
993     int res = -1;
994     struct validator state;
995     PyThreadState *tstate;
996     int recursion_limit = Py_GetRecursionLimit();
997     int starting_recursion_depth;
998 
999     /* Setup recursion depth check counters */
1000     tstate = _PyThreadState_GET();
1001     if (!tstate) {
1002         return 0;
1003     }
1004     /* Be careful here to prevent overflow. */
1005     int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
1006     starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1007         recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
1008     state.recursion_depth = starting_recursion_depth;
1009     state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1010         recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
1011 
1012     switch (mod->kind) {
1013     case Module_kind:
1014         res = validate_stmts(&state, mod->v.Module.body);
1015         break;
1016     case Interactive_kind:
1017         res = validate_stmts(&state, mod->v.Interactive.body);
1018         break;
1019     case Expression_kind:
1020         res = validate_expr(&state, mod->v.Expression.body, Load);
1021         break;
1022     case FunctionType_kind:
1023         res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
1024               validate_expr(&state, mod->v.FunctionType.returns, Load);
1025         break;
1026     // No default case so compiler emits warning for unhandled cases
1027     }
1028 
1029     if (res < 0) {
1030         PyErr_SetString(PyExc_SystemError, "impossible module node");
1031         return 0;
1032     }
1033 
1034     /* Check that the recursion depth counting balanced correctly */
1035     if (res && state.recursion_depth != starting_recursion_depth) {
1036         PyErr_Format(PyExc_SystemError,
1037             "AST validator recursion depth mismatch (before=%d, after=%d)",
1038             starting_recursion_depth, state.recursion_depth);
1039         return 0;
1040     }
1041     return res;
1042 }
1043 
1044 PyObject *
_PyAST_GetDocString(asdl_stmt_seq * body)1045 _PyAST_GetDocString(asdl_stmt_seq *body)
1046 {
1047     if (!asdl_seq_LEN(body)) {
1048         return NULL;
1049     }
1050     stmt_ty st = asdl_seq_GET(body, 0);
1051     if (st->kind != Expr_kind) {
1052         return NULL;
1053     }
1054     expr_ty e = st->v.Expr.value;
1055     if (e->kind == Constant_kind && PyUnicode_CheckExact(e->v.Constant.value)) {
1056         return e->v.Constant.value;
1057     }
1058     return NULL;
1059 }
1060