1 /*
2 * This file exposes PyAST_Validate interface to check the integrity
3 * of the given abstract syntax tree (potentially constructed manually).
4 */
5 #include "Python.h"
6 #include "pycore_ast.h" // asdl_stmt_seq
7 #include "pycore_pystate.h" // _PyThreadState_GET()
8
9 #include <assert.h>
10 #include <stdbool.h>
11
/* Bookkeeping threaded through every validate_* routine during one walk. */
struct validator {
    /* Depth of the walk right now: incremented on entry to each recursive
       validator and decremented again on that validator's normal exit. */
    int recursion_depth;
    /* Ceiling for recursion_depth; going past it raises RecursionError. */
    int recursion_limit;
};
16
17 static int validate_stmts(struct validator *, asdl_stmt_seq *);
18 static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
19 static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
20 static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
21 static int validate_stmt(struct validator *, stmt_ty);
22 static int validate_expr(struct validator *, expr_ty, expr_context_ty);
23 static int validate_pattern(struct validator *, pattern_ty, int);
24
/* Reject a node whose source-position fields are inconsistent.
 *
 * `node` must point to a struct with int fields lineno, end_lineno,
 * col_offset and end_col_offset.  When a check fails, a ValueError is set
 * and the *enclosing function* returns 0, so this may only be used inside
 * functions that return int.
 *
 * Checks, in order:
 *   - the start line must not come after the end line;
 *   - a negative lineno must equal end_lineno, and a negative col_offset
 *     must equal end_col_offset;
 *   - on a single line, the start column must not come after the end column.
 *
 * The body is wrapped in do { } while (0) so the macro behaves as exactly
 * one statement (safe under an unbraced if/else), and the argument is
 * parenthesised so any pointer expression can be passed.  Use it followed
 * by a semicolon.
 */
#define VALIDATE_POSITIONS(node) \
    do { \
        if ((node)->lineno > (node)->end_lineno) { \
            PyErr_Format(PyExc_ValueError, \
                         "AST node line range (%d, %d) is not valid", \
                         (node)->lineno, (node)->end_lineno); \
            return 0; \
        } \
        if (((node)->lineno < 0 && (node)->end_lineno != (node)->lineno) || \
            ((node)->col_offset < 0 && (node)->col_offset != (node)->end_col_offset)) { \
            PyErr_Format(PyExc_ValueError, \
                         "AST node column range (%d, %d) for line range (%d, %d) is not valid", \
                         (node)->col_offset, (node)->end_col_offset, (node)->lineno, (node)->end_lineno); \
            return 0; \
        } \
        if ((node)->lineno == (node)->end_lineno && (node)->col_offset > (node)->end_col_offset) { \
            PyErr_Format(PyExc_ValueError, \
                         "line %d, column %d-%d is not a valid range", \
                         (node)->lineno, (node)->col_offset, (node)->end_col_offset); \
            return 0; \
        } \
    } while (0)
45
46 static int
validate_name(PyObject * name)47 validate_name(PyObject *name)
48 {
49 assert(!PyErr_Occurred());
50 assert(PyUnicode_Check(name));
51 static const char * const forbidden[] = {
52 "None",
53 "True",
54 "False",
55 NULL
56 };
57 for (int i = 0; forbidden[i] != NULL; i++) {
58 if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
59 PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
60 return 0;
61 }
62 }
63 return 1;
64 }
65
66 static int
validate_comprehension(struct validator * state,asdl_comprehension_seq * gens)67 validate_comprehension(struct validator *state, asdl_comprehension_seq *gens)
68 {
69 assert(!PyErr_Occurred());
70 if (!asdl_seq_LEN(gens)) {
71 PyErr_SetString(PyExc_ValueError, "comprehension with no generators");
72 return 0;
73 }
74 for (Py_ssize_t i = 0; i < asdl_seq_LEN(gens); i++) {
75 comprehension_ty comp = asdl_seq_GET(gens, i);
76 if (!validate_expr(state, comp->target, Store) ||
77 !validate_expr(state, comp->iter, Load) ||
78 !validate_exprs(state, comp->ifs, Load, 0))
79 return 0;
80 }
81 return 1;
82 }
83
84 static int
validate_keywords(struct validator * state,asdl_keyword_seq * keywords)85 validate_keywords(struct validator *state, asdl_keyword_seq *keywords)
86 {
87 assert(!PyErr_Occurred());
88 for (Py_ssize_t i = 0; i < asdl_seq_LEN(keywords); i++)
89 if (!validate_expr(state, (asdl_seq_GET(keywords, i))->value, Load))
90 return 0;
91 return 1;
92 }
93
94 static int
validate_args(struct validator * state,asdl_arg_seq * args)95 validate_args(struct validator *state, asdl_arg_seq *args)
96 {
97 assert(!PyErr_Occurred());
98 for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) {
99 arg_ty arg = asdl_seq_GET(args, i);
100 VALIDATE_POSITIONS(arg);
101 if (arg->annotation && !validate_expr(state, arg->annotation, Load))
102 return 0;
103 }
104 return 1;
105 }
106
107 static const char *
expr_context_name(expr_context_ty ctx)108 expr_context_name(expr_context_ty ctx)
109 {
110 switch (ctx) {
111 case Load:
112 return "Load";
113 case Store:
114 return "Store";
115 case Del:
116 return "Del";
117 // No default case so compiler emits warning for unhandled cases
118 }
119 Py_UNREACHABLE();
120 }
121
122 static int
validate_arguments(struct validator * state,arguments_ty args)123 validate_arguments(struct validator *state, arguments_ty args)
124 {
125 assert(!PyErr_Occurred());
126 if (!validate_args(state, args->posonlyargs) || !validate_args(state, args->args)) {
127 return 0;
128 }
129 if (args->vararg && args->vararg->annotation
130 && !validate_expr(state, args->vararg->annotation, Load)) {
131 return 0;
132 }
133 if (!validate_args(state, args->kwonlyargs))
134 return 0;
135 if (args->kwarg && args->kwarg->annotation
136 && !validate_expr(state, args->kwarg->annotation, Load)) {
137 return 0;
138 }
139 if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
140 PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
141 return 0;
142 }
143 if (asdl_seq_LEN(args->kw_defaults) != asdl_seq_LEN(args->kwonlyargs)) {
144 PyErr_SetString(PyExc_ValueError, "length of kwonlyargs is not the same as "
145 "kw_defaults on arguments");
146 return 0;
147 }
148 return validate_exprs(state, args->defaults, Load, 0) && validate_exprs(state, args->kw_defaults, Load, 1);
149 }
150
151 static int
validate_constant(struct validator * state,PyObject * value)152 validate_constant(struct validator *state, PyObject *value)
153 {
154 assert(!PyErr_Occurred());
155 if (value == Py_None || value == Py_Ellipsis)
156 return 1;
157
158 if (PyLong_CheckExact(value)
159 || PyFloat_CheckExact(value)
160 || PyComplex_CheckExact(value)
161 || PyBool_Check(value)
162 || PyUnicode_CheckExact(value)
163 || PyBytes_CheckExact(value))
164 return 1;
165
166 if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
167 if (++state->recursion_depth > state->recursion_limit) {
168 PyErr_SetString(PyExc_RecursionError,
169 "maximum recursion depth exceeded during compilation");
170 return 0;
171 }
172
173 PyObject *it = PyObject_GetIter(value);
174 if (it == NULL)
175 return 0;
176
177 while (1) {
178 PyObject *item = PyIter_Next(it);
179 if (item == NULL) {
180 if (PyErr_Occurred()) {
181 Py_DECREF(it);
182 return 0;
183 }
184 break;
185 }
186
187 if (!validate_constant(state, item)) {
188 Py_DECREF(it);
189 Py_DECREF(item);
190 return 0;
191 }
192 Py_DECREF(item);
193 }
194
195 Py_DECREF(it);
196 --state->recursion_depth;
197 return 1;
198 }
199
200 if (!PyErr_Occurred()) {
201 PyErr_Format(PyExc_TypeError,
202 "got an invalid type in Constant: %s",
203 _PyType_Name(Py_TYPE(value)));
204 }
205 return 0;
206 }
207
208 static int
validate_expr(struct validator * state,expr_ty exp,expr_context_ty ctx)209 validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
210 {
211 assert(!PyErr_Occurred());
212 VALIDATE_POSITIONS(exp);
213 int ret = -1;
214 if (++state->recursion_depth > state->recursion_limit) {
215 PyErr_SetString(PyExc_RecursionError,
216 "maximum recursion depth exceeded during compilation");
217 return 0;
218 }
219 int check_ctx = 1;
220 expr_context_ty actual_ctx;
221
222 /* First check expression context. */
223 switch (exp->kind) {
224 case Attribute_kind:
225 actual_ctx = exp->v.Attribute.ctx;
226 break;
227 case Subscript_kind:
228 actual_ctx = exp->v.Subscript.ctx;
229 break;
230 case Starred_kind:
231 actual_ctx = exp->v.Starred.ctx;
232 break;
233 case Name_kind:
234 if (!validate_name(exp->v.Name.id)) {
235 return 0;
236 }
237 actual_ctx = exp->v.Name.ctx;
238 break;
239 case List_kind:
240 actual_ctx = exp->v.List.ctx;
241 break;
242 case Tuple_kind:
243 actual_ctx = exp->v.Tuple.ctx;
244 break;
245 default:
246 if (ctx != Load) {
247 PyErr_Format(PyExc_ValueError, "expression which can't be "
248 "assigned to in %s context", expr_context_name(ctx));
249 return 0;
250 }
251 check_ctx = 0;
252 /* set actual_ctx to prevent gcc warning */
253 actual_ctx = 0;
254 }
255 if (check_ctx && actual_ctx != ctx) {
256 PyErr_Format(PyExc_ValueError, "expression must have %s context but has %s instead",
257 expr_context_name(ctx), expr_context_name(actual_ctx));
258 return 0;
259 }
260
261 /* Now validate expression. */
262 switch (exp->kind) {
263 case BoolOp_kind:
264 if (asdl_seq_LEN(exp->v.BoolOp.values) < 2) {
265 PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values");
266 return 0;
267 }
268 ret = validate_exprs(state, exp->v.BoolOp.values, Load, 0);
269 break;
270 case BinOp_kind:
271 ret = validate_expr(state, exp->v.BinOp.left, Load) &&
272 validate_expr(state, exp->v.BinOp.right, Load);
273 break;
274 case UnaryOp_kind:
275 ret = validate_expr(state, exp->v.UnaryOp.operand, Load);
276 break;
277 case Lambda_kind:
278 ret = validate_arguments(state, exp->v.Lambda.args) &&
279 validate_expr(state, exp->v.Lambda.body, Load);
280 break;
281 case IfExp_kind:
282 ret = validate_expr(state, exp->v.IfExp.test, Load) &&
283 validate_expr(state, exp->v.IfExp.body, Load) &&
284 validate_expr(state, exp->v.IfExp.orelse, Load);
285 break;
286 case Dict_kind:
287 if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) {
288 PyErr_SetString(PyExc_ValueError,
289 "Dict doesn't have the same number of keys as values");
290 return 0;
291 }
292 /* null_ok=1 for keys expressions to allow dict unpacking to work in
293 dict literals, i.e. ``{**{a:b}}`` */
294 ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) &&
295 validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0);
296 break;
297 case Set_kind:
298 ret = validate_exprs(state, exp->v.Set.elts, Load, 0);
299 break;
300 #define COMP(NAME) \
301 case NAME ## _kind: \
302 ret = validate_comprehension(state, exp->v.NAME.generators) && \
303 validate_expr(state, exp->v.NAME.elt, Load); \
304 break;
305 COMP(ListComp)
306 COMP(SetComp)
307 COMP(GeneratorExp)
308 #undef COMP
309 case DictComp_kind:
310 ret = validate_comprehension(state, exp->v.DictComp.generators) &&
311 validate_expr(state, exp->v.DictComp.key, Load) &&
312 validate_expr(state, exp->v.DictComp.value, Load);
313 break;
314 case Yield_kind:
315 ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load);
316 break;
317 case YieldFrom_kind:
318 ret = validate_expr(state, exp->v.YieldFrom.value, Load);
319 break;
320 case Await_kind:
321 ret = validate_expr(state, exp->v.Await.value, Load);
322 break;
323 case Compare_kind:
324 if (!asdl_seq_LEN(exp->v.Compare.comparators)) {
325 PyErr_SetString(PyExc_ValueError, "Compare with no comparators");
326 return 0;
327 }
328 if (asdl_seq_LEN(exp->v.Compare.comparators) !=
329 asdl_seq_LEN(exp->v.Compare.ops)) {
330 PyErr_SetString(PyExc_ValueError, "Compare has a different number "
331 "of comparators and operands");
332 return 0;
333 }
334 ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) &&
335 validate_expr(state, exp->v.Compare.left, Load);
336 break;
337 case Call_kind:
338 ret = validate_expr(state, exp->v.Call.func, Load) &&
339 validate_exprs(state, exp->v.Call.args, Load, 0) &&
340 validate_keywords(state, exp->v.Call.keywords);
341 break;
342 case Constant_kind:
343 if (!validate_constant(state, exp->v.Constant.value)) {
344 return 0;
345 }
346 ret = 1;
347 break;
348 case JoinedStr_kind:
349 ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0);
350 break;
351 case FormattedValue_kind:
352 if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0)
353 return 0;
354 if (exp->v.FormattedValue.format_spec) {
355 ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load);
356 break;
357 }
358 ret = 1;
359 break;
360 case Attribute_kind:
361 ret = validate_expr(state, exp->v.Attribute.value, Load);
362 break;
363 case Subscript_kind:
364 ret = validate_expr(state, exp->v.Subscript.slice, Load) &&
365 validate_expr(state, exp->v.Subscript.value, Load);
366 break;
367 case Starred_kind:
368 ret = validate_expr(state, exp->v.Starred.value, ctx);
369 break;
370 case Slice_kind:
371 ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) &&
372 (!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) &&
373 (!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load));
374 break;
375 case List_kind:
376 ret = validate_exprs(state, exp->v.List.elts, ctx, 0);
377 break;
378 case Tuple_kind:
379 ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0);
380 break;
381 case NamedExpr_kind:
382 ret = validate_expr(state, exp->v.NamedExpr.value, Load);
383 break;
384 /* This last case doesn't have any checking. */
385 case Name_kind:
386 ret = 1;
387 break;
388 // No default case so compiler emits warning for unhandled cases
389 }
390 if (ret < 0) {
391 PyErr_SetString(PyExc_SystemError, "unexpected expression");
392 ret = 0;
393 }
394 state->recursion_depth--;
395 return ret;
396 }
397
398
399 // Note: the ensure_literal_* functions are only used to validate a restricted
400 // set of non-recursive literals that have already been checked with
401 // validate_expr, so they don't accept the validator state
402 static int
ensure_literal_number(expr_ty exp,bool allow_real,bool allow_imaginary)403 ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary)
404 {
405 assert(exp->kind == Constant_kind);
406 PyObject *value = exp->v.Constant.value;
407 return (allow_real && PyFloat_CheckExact(value)) ||
408 (allow_real && PyLong_CheckExact(value)) ||
409 (allow_imaginary && PyComplex_CheckExact(value));
410 }
411
412 static int
ensure_literal_negative(expr_ty exp,bool allow_real,bool allow_imaginary)413 ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary)
414 {
415 assert(exp->kind == UnaryOp_kind);
416 // Must be negation ...
417 if (exp->v.UnaryOp.op != USub) {
418 return 0;
419 }
420 // ... of a constant ...
421 expr_ty operand = exp->v.UnaryOp.operand;
422 if (operand->kind != Constant_kind) {
423 return 0;
424 }
425 // ... number
426 return ensure_literal_number(operand, allow_real, allow_imaginary);
427 }
428
429 static int
ensure_literal_complex(expr_ty exp)430 ensure_literal_complex(expr_ty exp)
431 {
432 assert(exp->kind == BinOp_kind);
433 expr_ty left = exp->v.BinOp.left;
434 expr_ty right = exp->v.BinOp.right;
435 // Ensure op is addition or subtraction
436 if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) {
437 return 0;
438 }
439 // Check LHS is a real number (potentially signed)
440 switch (left->kind)
441 {
442 case Constant_kind:
443 if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) {
444 return 0;
445 }
446 break;
447 case UnaryOp_kind:
448 if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) {
449 return 0;
450 }
451 break;
452 default:
453 return 0;
454 }
455 // Check RHS is an imaginary number (no separate sign allowed)
456 switch (right->kind)
457 {
458 case Constant_kind:
459 if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) {
460 return 0;
461 }
462 break;
463 default:
464 return 0;
465 }
466 return 1;
467 }
468
469 static int
validate_pattern_match_value(struct validator * state,expr_ty exp)470 validate_pattern_match_value(struct validator *state, expr_ty exp)
471 {
472 assert(!PyErr_Occurred());
473 if (!validate_expr(state, exp, Load)) {
474 return 0;
475 }
476
477 switch (exp->kind)
478 {
479 case Constant_kind:
480 /* Ellipsis and immutable sequences are not allowed.
481 For True, False and None, MatchSingleton() should
482 be used */
483 if (!validate_expr(state, exp, Load)) {
484 return 0;
485 }
486 PyObject *literal = exp->v.Constant.value;
487 if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
488 PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
489 PyUnicode_CheckExact(literal)) {
490 return 1;
491 }
492 PyErr_SetString(PyExc_ValueError,
493 "unexpected constant inside of a literal pattern");
494 return 0;
495 case Attribute_kind:
496 // Constants and attribute lookups are always permitted
497 return 1;
498 case UnaryOp_kind:
499 // Negated numbers are permitted (whether real or imaginary)
500 // Compiler will complain if AST folding doesn't create a constant
501 if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) {
502 return 1;
503 }
504 break;
505 case BinOp_kind:
506 // Complex literals are permitted
507 // Compiler will complain if AST folding doesn't create a constant
508 if (ensure_literal_complex(exp)) {
509 return 1;
510 }
511 break;
512 case JoinedStr_kind:
513 // Handled in the later stages
514 return 1;
515 default:
516 break;
517 }
518 PyErr_SetString(PyExc_ValueError,
519 "patterns may only match literals and attribute lookups");
520 return 0;
521 }
522
523 static int
validate_capture(PyObject * name)524 validate_capture(PyObject *name)
525 {
526 assert(!PyErr_Occurred());
527 if (_PyUnicode_EqualToASCIIString(name, "_")) {
528 PyErr_Format(PyExc_ValueError, "can't capture name '_' in patterns");
529 return 0;
530 }
531 return validate_name(name);
532 }
533
534 static int
validate_pattern(struct validator * state,pattern_ty p,int star_ok)535 validate_pattern(struct validator *state, pattern_ty p, int star_ok)
536 {
537 assert(!PyErr_Occurred());
538 VALIDATE_POSITIONS(p);
539 int ret = -1;
540 if (++state->recursion_depth > state->recursion_limit) {
541 PyErr_SetString(PyExc_RecursionError,
542 "maximum recursion depth exceeded during compilation");
543 return 0;
544 }
545 switch (p->kind) {
546 case MatchValue_kind:
547 ret = validate_pattern_match_value(state, p->v.MatchValue.value);
548 break;
549 case MatchSingleton_kind:
550 ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
551 if (!ret) {
552 PyErr_SetString(PyExc_ValueError,
553 "MatchSingleton can only contain True, False and None");
554 }
555 break;
556 case MatchSequence_kind:
557 ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
558 break;
559 case MatchMapping_kind:
560 if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
561 PyErr_SetString(PyExc_ValueError,
562 "MatchMapping doesn't have the same number of keys as patterns");
563 ret = 0;
564 break;
565 }
566
567 if (p->v.MatchMapping.rest && !validate_capture(p->v.MatchMapping.rest)) {
568 ret = 0;
569 break;
570 }
571
572 asdl_expr_seq *keys = p->v.MatchMapping.keys;
573 for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
574 expr_ty key = asdl_seq_GET(keys, i);
575 if (key->kind == Constant_kind) {
576 PyObject *literal = key->v.Constant.value;
577 if (literal == Py_None || PyBool_Check(literal)) {
578 /* validate_pattern_match_value will ensure the key
579 doesn't contain True, False and None but it is
580 syntactically valid, so we will pass those on in
581 a special case. */
582 continue;
583 }
584 }
585 if (!validate_pattern_match_value(state, key)) {
586 ret = 0;
587 break;
588 }
589 }
590 if (ret == 0) {
591 break;
592 }
593 ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
594 break;
595 case MatchClass_kind:
596 if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
597 PyErr_SetString(PyExc_ValueError,
598 "MatchClass doesn't have the same number of keyword attributes as patterns");
599 ret = 0;
600 break;
601 }
602 if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
603 ret = 0;
604 break;
605 }
606
607 expr_ty cls = p->v.MatchClass.cls;
608 while (1) {
609 if (cls->kind == Name_kind) {
610 break;
611 }
612 else if (cls->kind == Attribute_kind) {
613 cls = cls->v.Attribute.value;
614 continue;
615 }
616 else {
617 PyErr_SetString(PyExc_ValueError,
618 "MatchClass cls field can only contain Name or Attribute nodes.");
619 ret = 0;
620 break;
621 }
622 }
623 if (ret == 0) {
624 break;
625 }
626
627 for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
628 PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
629 if (!validate_name(identifier)) {
630 ret = 0;
631 break;
632 }
633 }
634 if (ret == 0) {
635 break;
636 }
637
638 if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
639 ret = 0;
640 break;
641 }
642
643 ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
644 break;
645 case MatchStar_kind:
646 if (!star_ok) {
647 PyErr_SetString(PyExc_ValueError, "can't use MatchStar here");
648 ret = 0;
649 break;
650 }
651 ret = p->v.MatchStar.name == NULL || validate_capture(p->v.MatchStar.name);
652 break;
653 case MatchAs_kind:
654 if (p->v.MatchAs.name && !validate_capture(p->v.MatchAs.name)) {
655 ret = 0;
656 break;
657 }
658 if (p->v.MatchAs.pattern == NULL) {
659 ret = 1;
660 }
661 else if (p->v.MatchAs.name == NULL) {
662 PyErr_SetString(PyExc_ValueError,
663 "MatchAs must specify a target name if a pattern is given");
664 ret = 0;
665 }
666 else {
667 ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0);
668 }
669 break;
670 case MatchOr_kind:
671 if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
672 PyErr_SetString(PyExc_ValueError,
673 "MatchOr requires at least 2 patterns");
674 ret = 0;
675 break;
676 }
677 ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
678 break;
679 // No default case, so the compiler will emit a warning if new pattern
680 // kinds are added without being handled here
681 }
682 if (ret < 0) {
683 PyErr_SetString(PyExc_SystemError, "unexpected pattern");
684 ret = 0;
685 }
686 state->recursion_depth--;
687 return ret;
688 }
689
690 static int
_validate_nonempty_seq(asdl_seq * seq,const char * what,const char * owner)691 _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
692 {
693 if (asdl_seq_LEN(seq))
694 return 1;
695 PyErr_Format(PyExc_ValueError, "empty %s on %s", what, owner);
696 return 0;
697 }
698 #define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner)
699
700 static int
validate_assignlist(struct validator * state,asdl_expr_seq * targets,expr_context_ty ctx)701 validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx)
702 {
703 assert(!PyErr_Occurred());
704 return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") &&
705 validate_exprs(state, targets, ctx, 0);
706 }
707
708 static int
validate_body(struct validator * state,asdl_stmt_seq * body,const char * owner)709 validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner)
710 {
711 assert(!PyErr_Occurred());
712 return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body);
713 }
714
715 static int
validate_stmt(struct validator * state,stmt_ty stmt)716 validate_stmt(struct validator *state, stmt_ty stmt)
717 {
718 assert(!PyErr_Occurred());
719 VALIDATE_POSITIONS(stmt);
720 int ret = -1;
721 if (++state->recursion_depth > state->recursion_limit) {
722 PyErr_SetString(PyExc_RecursionError,
723 "maximum recursion depth exceeded during compilation");
724 return 0;
725 }
726 switch (stmt->kind) {
727 case FunctionDef_kind:
728 ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
729 validate_arguments(state, stmt->v.FunctionDef.args) &&
730 validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) &&
731 (!stmt->v.FunctionDef.returns ||
732 validate_expr(state, stmt->v.FunctionDef.returns, Load));
733 break;
734 case ClassDef_kind:
735 ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") &&
736 validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) &&
737 validate_keywords(state, stmt->v.ClassDef.keywords) &&
738 validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0);
739 break;
740 case Return_kind:
741 ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load);
742 break;
743 case Delete_kind:
744 ret = validate_assignlist(state, stmt->v.Delete.targets, Del);
745 break;
746 case Assign_kind:
747 ret = validate_assignlist(state, stmt->v.Assign.targets, Store) &&
748 validate_expr(state, stmt->v.Assign.value, Load);
749 break;
750 case AugAssign_kind:
751 ret = validate_expr(state, stmt->v.AugAssign.target, Store) &&
752 validate_expr(state, stmt->v.AugAssign.value, Load);
753 break;
754 case AnnAssign_kind:
755 if (stmt->v.AnnAssign.target->kind != Name_kind &&
756 stmt->v.AnnAssign.simple) {
757 PyErr_SetString(PyExc_TypeError,
758 "AnnAssign with simple non-Name target");
759 return 0;
760 }
761 ret = validate_expr(state, stmt->v.AnnAssign.target, Store) &&
762 (!stmt->v.AnnAssign.value ||
763 validate_expr(state, stmt->v.AnnAssign.value, Load)) &&
764 validate_expr(state, stmt->v.AnnAssign.annotation, Load);
765 break;
766 case For_kind:
767 ret = validate_expr(state, stmt->v.For.target, Store) &&
768 validate_expr(state, stmt->v.For.iter, Load) &&
769 validate_body(state, stmt->v.For.body, "For") &&
770 validate_stmts(state, stmt->v.For.orelse);
771 break;
772 case AsyncFor_kind:
773 ret = validate_expr(state, stmt->v.AsyncFor.target, Store) &&
774 validate_expr(state, stmt->v.AsyncFor.iter, Load) &&
775 validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") &&
776 validate_stmts(state, stmt->v.AsyncFor.orelse);
777 break;
778 case While_kind:
779 ret = validate_expr(state, stmt->v.While.test, Load) &&
780 validate_body(state, stmt->v.While.body, "While") &&
781 validate_stmts(state, stmt->v.While.orelse);
782 break;
783 case If_kind:
784 ret = validate_expr(state, stmt->v.If.test, Load) &&
785 validate_body(state, stmt->v.If.body, "If") &&
786 validate_stmts(state, stmt->v.If.orelse);
787 break;
788 case With_kind:
789 if (!validate_nonempty_seq(stmt->v.With.items, "items", "With"))
790 return 0;
791 for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) {
792 withitem_ty item = asdl_seq_GET(stmt->v.With.items, i);
793 if (!validate_expr(state, item->context_expr, Load) ||
794 (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
795 return 0;
796 }
797 ret = validate_body(state, stmt->v.With.body, "With");
798 break;
799 case AsyncWith_kind:
800 if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith"))
801 return 0;
802 for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) {
803 withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i);
804 if (!validate_expr(state, item->context_expr, Load) ||
805 (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
806 return 0;
807 }
808 ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith");
809 break;
810 case Match_kind:
811 if (!validate_expr(state, stmt->v.Match.subject, Load)
812 || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) {
813 return 0;
814 }
815 for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
816 match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
817 if (!validate_pattern(state, m->pattern, /*star_ok=*/0)
818 || (m->guard && !validate_expr(state, m->guard, Load))
819 || !validate_body(state, m->body, "match_case")) {
820 return 0;
821 }
822 }
823 ret = 1;
824 break;
825 case Raise_kind:
826 if (stmt->v.Raise.exc) {
827 ret = validate_expr(state, stmt->v.Raise.exc, Load) &&
828 (!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load));
829 break;
830 }
831 if (stmt->v.Raise.cause) {
832 PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception");
833 return 0;
834 }
835 ret = 1;
836 break;
837 case Try_kind:
838 if (!validate_body(state, stmt->v.Try.body, "Try"))
839 return 0;
840 if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
841 !asdl_seq_LEN(stmt->v.Try.finalbody)) {
842 PyErr_SetString(PyExc_ValueError, "Try has neither except handlers nor finalbody");
843 return 0;
844 }
845 if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
846 asdl_seq_LEN(stmt->v.Try.orelse)) {
847 PyErr_SetString(PyExc_ValueError, "Try has orelse but no except handlers");
848 return 0;
849 }
850 for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) {
851 excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i);
852 VALIDATE_POSITIONS(handler);
853 if ((handler->v.ExceptHandler.type &&
854 !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
855 !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
856 return 0;
857 }
858 ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) ||
859 validate_stmts(state, stmt->v.Try.finalbody)) &&
860 (!asdl_seq_LEN(stmt->v.Try.orelse) ||
861 validate_stmts(state, stmt->v.Try.orelse));
862 break;
863 case TryStar_kind:
864 if (!validate_body(state, stmt->v.TryStar.body, "TryStar"))
865 return 0;
866 if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
867 !asdl_seq_LEN(stmt->v.TryStar.finalbody)) {
868 PyErr_SetString(PyExc_ValueError, "TryStar has neither except handlers nor finalbody");
869 return 0;
870 }
871 if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
872 asdl_seq_LEN(stmt->v.TryStar.orelse)) {
873 PyErr_SetString(PyExc_ValueError, "TryStar has orelse but no except handlers");
874 return 0;
875 }
876 for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.TryStar.handlers); i++) {
877 excepthandler_ty handler = asdl_seq_GET(stmt->v.TryStar.handlers, i);
878 if ((handler->v.ExceptHandler.type &&
879 !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
880 !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
881 return 0;
882 }
883 ret = (!asdl_seq_LEN(stmt->v.TryStar.finalbody) ||
884 validate_stmts(state, stmt->v.TryStar.finalbody)) &&
885 (!asdl_seq_LEN(stmt->v.TryStar.orelse) ||
886 validate_stmts(state, stmt->v.TryStar.orelse));
887 break;
888 case Assert_kind:
889 ret = validate_expr(state, stmt->v.Assert.test, Load) &&
890 (!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load));
891 break;
892 case Import_kind:
893 ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import");
894 break;
895 case ImportFrom_kind:
896 if (stmt->v.ImportFrom.level < 0) {
897 PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level");
898 return 0;
899 }
900 ret = validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom");
901 break;
902 case Global_kind:
903 ret = validate_nonempty_seq(stmt->v.Global.names, "names", "Global");
904 break;
905 case Nonlocal_kind:
906 ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal");
907 break;
908 case Expr_kind:
909 ret = validate_expr(state, stmt->v.Expr.value, Load);
910 break;
911 case AsyncFunctionDef_kind:
912 ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
913 validate_arguments(state, stmt->v.AsyncFunctionDef.args) &&
914 validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
915 (!stmt->v.AsyncFunctionDef.returns ||
916 validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load));
917 break;
918 case Pass_kind:
919 case Break_kind:
920 case Continue_kind:
921 ret = 1;
922 break;
923 // No default case so compiler emits warning for unhandled cases
924 }
925 if (ret < 0) {
926 PyErr_SetString(PyExc_SystemError, "unexpected statement");
927 ret = 0;
928 }
929 state->recursion_depth--;
930 return ret;
931 }
932
933 static int
validate_stmts(struct validator * state,asdl_stmt_seq * seq)934 validate_stmts(struct validator *state, asdl_stmt_seq *seq)
935 {
936 assert(!PyErr_Occurred());
937 for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) {
938 stmt_ty stmt = asdl_seq_GET(seq, i);
939 if (stmt) {
940 if (!validate_stmt(state, stmt))
941 return 0;
942 }
943 else {
944 PyErr_SetString(PyExc_ValueError,
945 "None disallowed in statement list");
946 return 0;
947 }
948 }
949 return 1;
950 }
951
952 static int
validate_exprs(struct validator * state,asdl_expr_seq * exprs,expr_context_ty ctx,int null_ok)953 validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok)
954 {
955 assert(!PyErr_Occurred());
956 for (Py_ssize_t i = 0; i < asdl_seq_LEN(exprs); i++) {
957 expr_ty expr = asdl_seq_GET(exprs, i);
958 if (expr) {
959 if (!validate_expr(state, expr, ctx))
960 return 0;
961 }
962 else if (!null_ok) {
963 PyErr_SetString(PyExc_ValueError,
964 "None disallowed in expression list");
965 return 0;
966 }
967
968 }
969 return 1;
970 }
971
972 static int
validate_patterns(struct validator * state,asdl_pattern_seq * patterns,int star_ok)973 validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
974 {
975 assert(!PyErr_Occurred());
976 for (Py_ssize_t i = 0; i < asdl_seq_LEN(patterns); i++) {
977 pattern_ty pattern = asdl_seq_GET(patterns, i);
978 if (!validate_pattern(state, pattern, star_ok)) {
979 return 0;
980 }
981 }
982 return 1;
983 }
984
985
986 /* See comments in symtable.c. */
987 #define COMPILER_STACK_FRAME_SCALE 3
988
989 int
_PyAST_Validate(mod_ty mod)990 _PyAST_Validate(mod_ty mod)
991 {
992 assert(!PyErr_Occurred());
993 int res = -1;
994 struct validator state;
995 PyThreadState *tstate;
996 int recursion_limit = Py_GetRecursionLimit();
997 int starting_recursion_depth;
998
999 /* Setup recursion depth check counters */
1000 tstate = _PyThreadState_GET();
1001 if (!tstate) {
1002 return 0;
1003 }
1004 /* Be careful here to prevent overflow. */
1005 int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
1006 starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1007 recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
1008 state.recursion_depth = starting_recursion_depth;
1009 state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1010 recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
1011
1012 switch (mod->kind) {
1013 case Module_kind:
1014 res = validate_stmts(&state, mod->v.Module.body);
1015 break;
1016 case Interactive_kind:
1017 res = validate_stmts(&state, mod->v.Interactive.body);
1018 break;
1019 case Expression_kind:
1020 res = validate_expr(&state, mod->v.Expression.body, Load);
1021 break;
1022 case FunctionType_kind:
1023 res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
1024 validate_expr(&state, mod->v.FunctionType.returns, Load);
1025 break;
1026 // No default case so compiler emits warning for unhandled cases
1027 }
1028
1029 if (res < 0) {
1030 PyErr_SetString(PyExc_SystemError, "impossible module node");
1031 return 0;
1032 }
1033
1034 /* Check that the recursion depth counting balanced correctly */
1035 if (res && state.recursion_depth != starting_recursion_depth) {
1036 PyErr_Format(PyExc_SystemError,
1037 "AST validator recursion depth mismatch (before=%d, after=%d)",
1038 starting_recursion_depth, state.recursion_depth);
1039 return 0;
1040 }
1041 return res;
1042 }
1043
1044 PyObject *
_PyAST_GetDocString(asdl_stmt_seq * body)1045 _PyAST_GetDocString(asdl_stmt_seq *body)
1046 {
1047 if (!asdl_seq_LEN(body)) {
1048 return NULL;
1049 }
1050 stmt_ty st = asdl_seq_GET(body, 0);
1051 if (st->kind != Expr_kind) {
1052 return NULL;
1053 }
1054 expr_ty e = st->v.Expr.value;
1055 if (e->kind == Constant_kind && PyUnicode_CheckExact(e->v.Constant.value)) {
1056 return e->v.Constant.value;
1057 }
1058 return NULL;
1059 }
1060