1 #pragma once
2 #include <memory>
3 #include <optional>
4 #include <string>
5 #include <utility>
6
7 #include <ATen/core/symbol.h>
8 #include <caffe2/serialize/versions.h>
9 #include <torch/csrc/jit/api/module.h>
10 #include <torch/csrc/jit/frontend/error_report.h>
11 #include <torch/csrc/jit/frontend/schema_matching.h>
12 #include <torch/csrc/jit/frontend/versioned_symbols.h>
13 #include <torch/csrc/jit/ir/ir.h>
14
15 namespace torch::jit {
16
17 using SugaredValuePtr = std::shared_ptr<SugaredValue>;
18
19 // The AST can contain nodes like `self`, `self.b` or `python_fn` that
20 // are not first-class values in the graph representation, but instead
21 // will be desugared based on how they are used in the AST.
22
23 // SugaredValue is used to temporarily represent these values in a way
24 // that separates their behavior from the AST -> IR converter itself.
25 // This allows us to keep dependencies on python minimal.
26
27 struct TORCH_API SugaredValue
28 : public std::enable_shared_from_this<SugaredValue> {
29 // what is this node? for error reporting (e.g. Module, python function)
30 virtual std::string kind() const = 0;
31
32 // what can we do with this thing?
33 // use it as a value e.g. `this + 4`
asValueSugaredValue34 virtual Value* asValue(const SourceRange& loc, GraphFunction& m) {
35 throw(ErrorReport(loc) << kind() << " cannot be used as a value");
36 }
37
38 // select an attribute on it, e.g. `this.field`
attrSugaredValue39 virtual std::shared_ptr<SugaredValue> attr(
40 const SourceRange& loc,
41 GraphFunction& m,
42 const std::string& field) {
43 throw(ErrorReport(loc) << "attribute lookup is not defined on " << kind());
44 }
45
hasAttrSugaredValue46 virtual bool hasAttr(
47 const SourceRange& loc,
48 GraphFunction& m,
49 const std::string& field) {
50 throw(ErrorReport(loc) << "attribute lookup is not defined on " << kind());
51 }
52
53 // assign an attribute on it, e.g. `this.field = newValue`
setAttrSugaredValue54 virtual void setAttr(
55 const SourceRange& loc,
56 GraphFunction& m,
57 const std::string& field,
58 Value* newValue) {
59 throw(
60 ErrorReport(loc) << "attribute assignment is not defined on "
61 << kind());
62 }
63
64 // use it as a vector of values, e.g. a tuple of values as return value from
65 // a method invocation
66 virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
67 const SourceRange& loc,
68 GraphFunction& m,
69 const std::optional<size_t>& size_hint = {}) {
70 throw(ErrorReport(loc) << kind() << " cannot be used as a tuple");
71 }
72
73 // TODO @wconstab refactor to use ModuleValue::asTuple instead of new API
asTupleValueSugaredValue74 virtual SugaredValuePtr asTupleValue(
75 const SourceRange& loc,
76 GraphFunction& m) {
77 throw(ErrorReport(loc) << kind() << " cannot be used as a tuplevalue");
78 }
79
asTypeSugaredValue80 virtual std::vector<std::shared_ptr<SugaredValue>> asType(
81 const SourceRange& loc,
82 Method& m) {
83 throw(ErrorReport(loc) << kind() << " cannot be used as a type");
84 }
85
86 // call it like a function, e.g. `outputs = this(inputs)`
callSugaredValue87 virtual std::shared_ptr<SugaredValue> call(
88 const SourceRange& loc,
89 GraphFunction& m,
90 // note: names for args will be 'argument 0', 'argument 1', etc..
91 at::ArrayRef<NamedValue> args,
92 at::ArrayRef<NamedValue> kwargs,
93 size_t n_binders) {
94 // n_binders is always set to the number of variables an expression is
95 // syntactically bound to:
96 // a = foo() # 1 binder (note in this case the single binder might be a
97 // tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0
98 // binders
99 //
100 // In subexpressions, like bar() in foo(bar()), n_binders is always set to
101 // 1. n_binders is used as a hint to subexpressions to determine how many
102 // values they should return when that number is ambiguous statically. In
103 // particular it is currently used to decide how many tensors a call to a
104 // python function will return. It is only a hint, functions do not have to
105 // check that n_binders match the number of things they are returning, the
106 // assignment logic will do that anyway.
107
108 throw(ErrorReport(loc) << "cannot call a " << kind());
109 }
110
111 // This function is called when to convert a SugaredValue to its iterator.
112 // For example, when iterating through a Dict we iterate over its keys
iterSugaredValue113 virtual std::shared_ptr<SugaredValue> iter(
114 const SourceRange& loc,
115 GraphFunction& m) {
116 throw(ErrorReport(loc) << kind() << " cannot be used as an iterable");
117 }
118
119 // If we are iterating over a Sugared Value and it returns a value from this
120 // function, then we emit an unrolled loop over the variable. This allows us
121 // to support containers of Heterogenous types, like Module Containers &
122 // Tuples
staticLenSugaredValue123 virtual std::optional<int64_t> staticLen() {
124 return std::nullopt;
125 }
126
127 // When iterating over this SugaredValue, should we emit the for loop as an
128 // unrolled loop.
shouldEmitUnrolledSugaredValue129 bool shouldEmitUnrolled() {
130 return staticLen() != std::nullopt;
131 }
132
133 // return length of this thing, if not then it can't be iterated.
134 // If it does not have a statically-determinable length, then it cannot
135 // be iterated over with a modulelist. If it does it must return a constant
136 // Value *
lenSugaredValue137 virtual Value* len(const SourceRange& loc, GraphFunction& m) {
138 throw(
139 ErrorReport(loc) << "'" << kind() << "'"
140 << " object is not iterable");
141 }
142
143 // expression for ith elemement for iterable value
144 virtual std::shared_ptr<SugaredValue> getitem(
145 const SourceRange& loc,
146 GraphFunction& m,
147 Value* idx,
148 TypePtr type_hint = nullptr) {
149 throw(
150 ErrorReport(loc) << "'" << kind() << "'"
151 << " object is not subscriptable");
152 }
153
154 virtual ~SugaredValue() = default;
155 };
156
157 // most things in the environment are just simple value types
158 // and not special python syntax sugar types
159 struct TORCH_API SimpleValue : public SugaredValue {
SimpleValueSimpleValue160 SimpleValue(Value* value) : value_(value) {}
kindSimpleValue161 std::string kind() const override {
162 std::stringstream ss;
163 // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
164 ss << "value of type '" << value_->type()->annotation_str() << "'";
165 return ss.str();
166 }
asValueSimpleValue167 Value* asValue(const SourceRange& range, GraphFunction& m) override {
168 return value_;
169 }
170 std::vector<std::shared_ptr<SugaredValue>> asTuple(
171 const SourceRange& loc,
172 GraphFunction& m,
173 const std::optional<size_t>& size_hint = {}) override;
174 std::shared_ptr<SugaredValue> attr(
175 const SourceRange& loc,
176 GraphFunction& m,
177 const std::string& field) override;
178
179 bool hasAttr(
180 const SourceRange& loc,
181 GraphFunction& m,
182 const std::string& field) override;
183
184 void setAttr(
185 const SourceRange& loc,
186 GraphFunction& m,
187 const std::string& field,
188 Value* newValue) override;
189
190 std::shared_ptr<SugaredValue> call(
191 const SourceRange& loc,
192 GraphFunction& m,
193 // note: names for args will be 'argument 0', 'argument 1', etc..
194 at::ArrayRef<NamedValue> args,
195 at::ArrayRef<NamedValue> kwargs,
196 size_t n_binders) override;
197
198 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
199 override;
200
getValueSimpleValue201 Value* getValue() const {
202 return value_;
203 }
204
205 Value* len(const SourceRange& loc, GraphFunction& m) override;
206 SugaredValuePtr getitem(
207 const SourceRange& loc,
208 GraphFunction& m,
209 Value* idx,
210 TypePtr type_hint = nullptr) override;
211
212 private:
213 Value* value_;
214 };
215
216 struct TORCH_API BuiltinFunction : public SugaredValue {
BuiltinFunctionBuiltinFunction217 BuiltinFunction(Symbol symbol, std::optional<NamedValue> self)
218 : symbol(symbol), self(std::move(self)) {}
219
220 // The symbol of the function (e.g. `aten::relu`).
221 Symbol symbol;
222
223 // if this is method, then this is the self argument.
224 std::optional<NamedValue> self;
kindBuiltinFunction225 std::string kind() const override {
226 return "builtin";
227 }
228 std::shared_ptr<SugaredValue> call(
229 const SourceRange& loc,
230 GraphFunction& m,
231 at::ArrayRef<NamedValue> args,
232 at::ArrayRef<NamedValue> kwargs,
233 size_t n_binders) override;
234
235 // try to create this builtin but if it doesn't exist or the self argument
236 // cannot possibly match, then return nullptr. Use in situations where it is
237 // not clear if it is a valid builtin
238 static std::shared_ptr<BuiltinFunction> tryCreate(
239 Symbol symbol,
240 std::optional<NamedValue> self);
241 };
242
243 struct TORCH_API SugaredTupleValue : public SugaredValue {
SugaredTupleValueSugaredTupleValue244 explicit SugaredTupleValue(std::vector<std::shared_ptr<SugaredValue>> tup)
245 : tup_(std::move(tup)){};
246
247 std::vector<std::shared_ptr<SugaredValue>> asTuple(
248 const SourceRange& loc,
249 GraphFunction& m,
250 const std::optional<size_t>& size_hint = {}) override {
251 return tup_;
252 };
253
asValueSugaredTupleValue254 Value* asValue(const SourceRange& loc, GraphFunction& m) override {
255 std::vector<Value*> vec;
256 vec.reserve(tup_.size());
257 for (const auto& sv : tup_) {
258 vec.push_back(sv->asValue(loc, m));
259 }
260 Graph& g = *m.graph();
261 return g.insertNode(g.createTuple(vec))->output();
262 }
263
kindSugaredTupleValue264 std::string kind() const override {
265 return "Tuple";
266 }
267
268 SugaredValuePtr getitem(
269 const SourceRange& loc,
270 GraphFunction& m,
271 Value* idx,
272 TypePtr type_hint = nullptr) override {
273 if (!(idx->type()->cast<IntType>() && toIValue(idx))) {
274 throw(
275 ErrorReport(loc)
276 << "Expected integer literal for index but got a variable or non-integer. "
277 << "ModuleList/Sequential indexing is only supported with integer literals. "
278 << "For example, 'i = 4; self.layers[i](x)' will fail because i is not a literal. "
279 << "Enumeration is supported, e.g. 'for index, v in enumerate(self): out = v(inp)'");
280 }
281 auto index = toIValue(idx)->toInt();
282 int64_t adj_index =
283 (index < 0) ? index + static_cast<int64_t>(tup_.size()) : index;
284 if (!(adj_index >= 0 && adj_index < static_cast<int64_t>(tup_.size()))) {
285 throw(
286 ErrorReport(loc) << "Index " << index << " out of range of length "
287 << tup_.size());
288 }
289 return tup_.at(adj_index);
290 }
291
292 // This function is called when a SugaredValue is used to convert a
293 // SugaredValue to its iterator. For example, when iterating through a Dict we
294 // iterate over its keys
iterSugaredTupleValue295 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
296 override {
297 return shared_from_this();
298 };
299
300 // Because this is used to contain SugaredValues of Heterogenous types,
301 // we define staticLen() so that when this is iterated over it is emitted
302 // as an unrolled loop.
staticLenSugaredTupleValue303 std::optional<int64_t> staticLen() override {
304 return static_cast<int64_t>(tup_.size());
305 }
306
307 std::vector<std::shared_ptr<SugaredValue>> tup_;
308 };
309
310 struct TORCH_API BuiltinModule : public SugaredValue {
311 BuiltinModule(std::string name, std::optional<int64_t> version = std::nullopt)
nameBuiltinModule312 : name(std::move(name)), version(version) {}
313
kindBuiltinModule314 std::string kind() const override {
315 return "builtin module";
316 }
attrBuiltinModule317 std::shared_ptr<SugaredValue> attr(
318 const SourceRange& loc,
319 GraphFunction& m,
320 const std::string& field) override {
321 if (field == "autograd") {
322 // When refering torch.autograd, it is also considered to be a
323 // BuiltinModule and we will dispatch to the aten operators for the
324 // methods under its module.
325 return std::make_shared<BuiltinModule>("aten", version);
326 }
327
328 auto sym = Symbol::fromQualString(name + "::" + field);
329 return std::make_shared<BuiltinFunction>(sym, std::nullopt);
330 }
331
332 private:
333 std::string name;
334 // when we add operator versioning, emit this op as it exising at 'version'
335 // if not set, use the latest version
336 std::optional<int64_t> version;
337 };
338
339 // Represents a class, analagous to `int` or `dict`. Instances of classes,
340 // like `1` or `{"foo": 5}`, are represented as SimpleValues
341 struct TORCH_API ClassValue : public SugaredValue {
ClassValueClassValue342 explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {}
343
344 // Call the type's constructor, as in:
345 // n = Foo(constructor_arg)
346 std::shared_ptr<SugaredValue> call(
347 const SourceRange& loc,
348 GraphFunction& m,
349 at::ArrayRef<NamedValue> args,
350 at::ArrayRef<NamedValue> kwargs,
351 size_t n_binders) override;
352
353 std::shared_ptr<SugaredValue> attr(
354 const SourceRange& loc,
355 GraphFunction& m,
356 const std::string& field) override;
357
kindClassValue358 std::string kind() const override {
359 return type_->str();
360 }
361
362 ClassTypePtr type_;
363 };
364
365 struct TORCH_API NamedTupleConstructor : public SugaredValue {
NamedTupleConstructorNamedTupleConstructor366 explicit NamedTupleConstructor(TupleTypePtr type) : type_(std::move(type)) {}
367
368 std::shared_ptr<SugaredValue> call(
369 const SourceRange& loc,
370 GraphFunction& m,
371 at::ArrayRef<NamedValue> args,
372 at::ArrayRef<NamedValue> kwargs,
373 size_t n_binders) override;
374
kindNamedTupleConstructor375 std::string kind() const override {
376 return type_->str();
377 }
378
379 TupleTypePtr type_;
380 };
381
382 struct FunctionValue : public SugaredValue {
FunctionValueFunctionValue383 FunctionValue(Function* callee) : callees_({callee}) {}
FunctionValueFunctionValue384 FunctionValue(const StrongFunctionPtr& p)
385 : callees_({p.function_}), cu_(p.cu_) {}
FunctionValueFunctionValue386 FunctionValue(const std::vector<StrongFunctionPtr>& callees) {
387 for (const StrongFunctionPtr& callee : callees) {
388 cu_ = cu_ ? cu_ : callee.cu_;
389 TORCH_INTERNAL_ASSERT(callee.cu_ == cu_);
390 callees_.push_back(callee.function_);
391 }
392 }
393
kindFunctionValue394 std::string kind() const override {
395 return "function";
396 }
397
callFunctionValue398 std::shared_ptr<SugaredValue> call(
399 const SourceRange& loc,
400 GraphFunction& f,
401 at::ArrayRef<NamedValue> args,
402 at::ArrayRef<NamedValue> kwargs,
403 size_t n_binders) override {
404 std::vector<const FunctionSchema*> schemas;
405 for (Function* callee : callees_) {
406 try {
407 callee->ensure_defined();
408 } catch (const RecursiveMethodCallError&) {
409 throw(
410 ErrorReport(loc)
411 << " function '" << callee->name() << "' is called recursively. "
412 << "Recursive calls are not supported");
413 }
414 schemas.push_back(&callee->getSchema());
415 }
416 auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs);
417 Value* output =
418 f.graph()->insertFunctionCall(callees_[match.first], match.second);
419 output->node()->setSourceRange(loc);
420 return std::make_shared<SimpleValue>(output);
421 }
422
calleesFunctionValue423 const std::vector<Function*>& callees() {
424 return callees_;
425 }
426
427 private:
428 std::vector<Function*> callees_;
429 // TODO holding this thing is creepy
430 std::shared_ptr<CompilationUnit> cu_;
431 };
432
433 struct TORCH_API ClosureValue : public SugaredValue {
ClosureValueClosureValue434 ClosureValue(Value* value) : value_(value) {
435 TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure);
436 }
kindClosureValue437 std::string kind() const override {
438 return "closure";
439 }
asValueClosureValue440 Value* asValue(const SourceRange& range, GraphFunction& m) override {
441 return value_;
442 }
443 Value* value_;
444 };
445
446 // defines how a method obtained from a module/class/interface behaves in script
447 struct MethodValue : public SugaredValue {
MethodValueMethodValue448 MethodValue(Value* self, std::vector<std::string> method_names)
449 : self_(self), method_names_(std::move(method_names)) {}
MethodValueMethodValue450 MethodValue(Value* self, std::string method_name)
451 : MethodValue(self, std::vector<std::string>({std::move(method_name)})) {}
452
kindMethodValue453 std::string kind() const override {
454 return "method";
455 }
456
callMethodValue457 std::shared_ptr<SugaredValue> call(
458 const SourceRange& loc,
459 GraphFunction& f,
460 at::ArrayRef<NamedValue> args,
461 at::ArrayRef<NamedValue> kwargs,
462 size_t n_binders) override {
463 std::vector<NamedValue> argsWithSelf = {self_};
464 argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end());
465 std::vector<const FunctionSchema*> schemas;
466 for (const std::string& method_name : method_names_) {
467 if (auto class_type = self_->type()->cast<ClassType>()) {
468 Function& method = class_type->getMethod(method_name);
469 try {
470 method.ensure_defined();
471 } catch (const RecursiveMethodCallError&) {
472 throw(
473 ErrorReport(loc)
474 << " method '" << method.name() << "' is called recursively. "
475 << "Recursive calls are not supported");
476 }
477 schemas.push_back(&method.getSchema());
478 } else if (auto interface_type = self_->type()->cast<InterfaceType>()) {
479 schemas.push_back(interface_type->getMethod(method_name));
480 } else {
481 TORCH_INTERNAL_ASSERT(
482 false, "method constructed that is not a class or interface");
483 }
484 }
485 auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs);
486 Value* output =
487 f.graph()->insertMethodCall(method_names_[match.first], match.second);
488 output->node()->setSourceRange(loc);
489 return std::make_shared<SimpleValue>(output);
490 }
491
492 private:
493 Value* self_;
494 std::vector<std::string> method_names_;
495 };
496
497 struct TORCH_API PrintValue : public SugaredValue {
kindPrintValue498 std::string kind() const override {
499 return "print";
500 }
501 std::shared_ptr<SugaredValue> call(
502 const SourceRange& loc,
503 GraphFunction& m,
504 at::ArrayRef<NamedValue> args,
505 at::ArrayRef<NamedValue> kwargs,
506 size_t n_binders) override;
507 };
508
509 // expressions like int(x)
510 // these are the same as call prim::Int or equivalent except it
511 // is a noop when the input is a subtype of 'type'
512 struct TORCH_API CastValue : public BuiltinFunction {
CastValueCastValue513 CastValue(TypePtr type, c10::Symbol method)
514 : BuiltinFunction(method, std::nullopt), type_(std::move(type)) {}
callCastValue515 std::shared_ptr<SugaredValue> call(
516 const SourceRange& loc,
517 GraphFunction& m,
518 at::ArrayRef<NamedValue> args,
519 at::ArrayRef<NamedValue> kwargs,
520 size_t n_binders) override {
521 if (args.size() == 1 && kwargs.empty()) {
522 auto len_op = std::make_shared<BuiltinFunction>(aten::len, std::nullopt);
523 auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, std::nullopt);
524 auto zero = m.graph()->insertConstant(0);
525
526 auto v = args[0].value(*m.graph());
527 if (v->type()->isSubtypeOf(*type_)) {
528 return std::make_shared<SimpleValue>(v);
529 } else if (
530 *type_ == *BoolType::get() &&
531 (v->type()->isSubtypeOf(*AnyListType::get()) ||
532 v->type()->isSubtypeOf(*StringType::get()) ||
533 v->type()->cast<DictType>())) {
534 auto len = len_op->call(loc, m, {v}, {}, 1);
535 return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1);
536 }
537 }
538 return BuiltinFunction::call(loc, m, args, kwargs, n_binders);
539 }
540
541 private:
542 TypePtr type_;
543 };
544
545 struct TORCH_API TensorCastValue : public SugaredValue {
TensorCastValueTensorCastValue546 TensorCastValue(at::ScalarType type, NamedValue self)
547 : dtype_(type), self_(std::move(self)) {}
548
kindTensorCastValue549 std::string kind() const override {
550 return "Cast";
551 }
552
callTensorCastValue553 std::shared_ptr<SugaredValue> call(
554 const SourceRange& loc,
555 GraphFunction& m,
556 at::ArrayRef<NamedValue> args,
557 at::ArrayRef<NamedValue> kwargs,
558 size_t n_binders) override {
559 TORCH_INTERNAL_ASSERT(args.empty() && kwargs.empty());
560 Value* dtype_const = m.graph()->insertConstant(dtype_, loc);
561 std::vector<NamedValue> kwargs_{
562 self_, NamedValue(loc, "dtype", dtype_const)};
563 Value* casted_val = m.graph()->insert(
564 /*opname=*/Symbol::fromQualString("aten::to"),
565 /*args=*/args,
566 /*kwargs=*/kwargs_,
567 /*range=*/loc);
568 return std::make_shared<SimpleValue>(casted_val);
569 }
570
571 at::ScalarType dtype_;
572 NamedValue self_;
573 };
574
575 // builtins operators and functions that call a method if it exists
576 // on a class type, like 'len(x)' and 'x + y'
577 struct TORCH_API MagicMethod : public SugaredValue {
MagicMethodMagicMethod578 MagicMethod(std::string desugared_name, SugaredValuePtr base)
579 : base_value_(std::move(base)),
580 desugared_name_(std::move(desugared_name)) {}
581
kindMagicMethod582 std::string kind() const override {
583 return desugared_name_;
584 }
585
586 std::shared_ptr<SugaredValue> call(
587 const SourceRange& loc,
588 GraphFunction& m,
589 at::ArrayRef<NamedValue> args,
590 at::ArrayRef<NamedValue> kwargs,
591 size_t n_binders) override;
592
593 private:
594 SugaredValuePtr base_value_;
595 std::string desugared_name_;
596 };
597
598 // things that look like function applications, but
599 // perform non-standard evaluation are represented
600 // with SpecialFormValues, e.g.
601 // isinstance(x, int)
602 // fork(fn)
603 // annotate(int, 3)
604 // The implementation of each value is handled by a case inside emitApplyExpr
605 struct TORCH_API SpecialFormValue : public SugaredValue {
SpecialFormValueSpecialFormValue606 SpecialFormValue(Symbol form) : form_(form) {}
kindSpecialFormValue607 std::string kind() const override {
608 return form_.toUnqualString();
609 }
formSpecialFormValue610 Symbol form() const {
611 return form_;
612 }
createSpecialFormValue613 static std::shared_ptr<SpecialFormValue> create(Symbol form) {
614 return std::make_shared<SpecialFormValue>(form);
615 }
616
617 private:
618 Symbol form_;
619 };
620
621 struct TORCH_API LegacyTensorConstructor : public SpecialFormValue {
LegacyTensorConstructorLegacyTensorConstructor622 LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device)
623 : SpecialFormValue(form), device_(device), dtype_(dtype) {}
624
createLegacyTensorConstructor625 static std::shared_ptr<LegacyTensorConstructor> create(
626 Symbol form,
627 at::ScalarType dtype,
628 at::Device device) {
629 return std::make_shared<LegacyTensorConstructor>(form, dtype, device);
630 }
dtypeLegacyTensorConstructor631 at::ScalarType dtype() const {
632 return dtype_;
633 }
634
635 private:
636 at::Device device_;
637 at::ScalarType dtype_;
638 };
639
640 // matched against for special handling of range expressions
641 struct TORCH_API RangeValue : SugaredValue {
642 RangeValue(
643 const SourceRange& loc,
644 GraphFunction& m,
645 std::vector<Value*> input,
646 std::optional<int64_t> static_len = std::nullopt);
647
kindRangeValue648 std::string kind() const override {
649 return "range";
650 }
651 Value* len(const SourceRange& loc, GraphFunction& m) override;
652 SugaredValuePtr getitem(
653 const SourceRange& loc,
654 GraphFunction& m,
655 Value* idx,
656 TypePtr type_hint = nullptr) override;
657 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
658 override;
659
660 // When Range is instantiated via enumerate(iterable_with_static_len),
661 // then it takes the static length of the iterable
staticLenRangeValue662 std::optional<int64_t> staticLen() override {
663 return static_len_;
664 }
665
666 private:
667 Value* start_{};
668 Value* end_{};
669 Value* step_{};
670 // a flag to determine if it's a simple range() call with only end_ from
671 // arguments If true, we will not insert length calculation and index
672 // derivation nodes to simplify the graph and enable more possible
673 // optimizations
674 bool has_only_end_{};
675 std::optional<int64_t> static_len_;
676 };
677
678 // Specialized Tree structure to matched against for special handling
679 // of builtin functions iterables expressions like zip(), enumerate(), etc.
680 // zip and enumerate can be modeled as a tree of SimpleValue/RangeValue:
681 // zip(x, y) -> (x, y) with tuple assignment to each loop target
682 // enumerate(x) -> (range(0, math.inf, 1), x)
683 // So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be:
684 // (a, (range(0, math.inf, 1), b), range(0, 100))
685 // We use those base iterables to fill in the loop information like
686 // max_trip_count and set the value table for loop targets
687 // Iterables can contain lists of SugaredValues like ModuleLists. If it
688 // does, then we emit it unrolled and require that all values it contains
689 // have a statically-determinable length.
690 struct TORCH_API IterableTree : SugaredValue {
691 IterableTree() = default;
IterableTreeIterableTree692 IterableTree(
693 const SourceRange& range,
694 GraphFunction& m,
695 at::ArrayRef<SugaredValuePtr> children) {
696 for (const auto& child : children) {
697 addChild(range, m, child);
698 }
699 }
kindIterableTree700 std::string kind() const override {
701 return "iterabletree";
702 }
703
iterIterableTree704 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
705 override {
706 return shared_from_this();
707 }
708
709 void addChild(
710 const SourceRange& range,
711 GraphFunction& m,
712 const SugaredValuePtr& iter_value);
713
get_childrenIterableTree714 std::vector<SugaredValuePtr> get_children() {
715 return children_;
716 }
717
718 // If this iterable contains a ModuleList or Tuple, then it will have a
719 // static length, and we will emit it as an unrolled for loop.
staticLenIterableTree720 std::optional<int64_t> staticLen() override {
721 return unroll_length_;
722 }
723
724 // given a IterableTree node, get all the base iterables/leaves under the
725 // IterableTree node. This enables
726 // us to get all the basic SugaredValues that contains valid loop information
727 // with len() and getitem()
728 std::vector<SugaredValuePtr> get_base_iterables();
729
730 Value* len(const SourceRange& loc, GraphFunction& m) override;
731 SugaredValuePtr getitem(
732 const SourceRange& loc,
733 GraphFunction& m,
734 Value* idx,
735 TypePtr type_hint = nullptr) override;
736
737 private:
738 std::optional<int64_t> unroll_length_ = std::nullopt;
739 std::vector<SugaredValuePtr> children_;
740 };
741
toValues(Graph & g,at::ArrayRef<NamedValue> nvs)742 static inline std::vector<Value*> toValues(
743 Graph& g,
744 at::ArrayRef<NamedValue> nvs) {
745 return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
746 }
747
748 struct SimpleSelf : public Self {
SimpleSelfSimpleSelf749 explicit SimpleSelf(ClassTypePtr classType)
750 : Self(), classType_(std::move(classType)) {}
makeSugaredSimpleSelf751 std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
752 v->setType(classType_);
753 return std::make_shared<SimpleValue>(v);
754 }
getClassTypeSimpleSelf755 ClassTypePtr getClassType() const override {
756 return classType_;
757 }
758
759 private:
760 ClassTypePtr classType_;
761 };
762
763 // This is not a SimpleValue so it can not pass through the code paths that
764 // expect a SimpleValue as a sugared value.
765 struct TORCH_API ExceptionMessageValue : public SugaredValue {
766 explicit ExceptionMessageValue(
767 Value* value,
768 Value* qualified_class_name = nullptr)
value_ExceptionMessageValue769 : value_(value), qualified_class_name_(qualified_class_name) {}
770
kindExceptionMessageValue771 std::string kind() const override {
772 return "exception message";
773 }
774
getValueExceptionMessageValue775 Value* getValue() {
776 return value_;
777 }
778
779 // qualified python class name
getQualifiedClassNameExceptionMessageValue780 Value* getQualifiedClassName() {
781 return qualified_class_name_;
782 }
783
784 private:
785 Value* value_;
786 Value* qualified_class_name_;
787 };
788
789 struct TORCH_API ExceptionValue : public SugaredValue {
ExceptionValueExceptionValue790 explicit ExceptionValue(std::string message) : message_(std::move(message)) {}
791
kindExceptionValue792 std::string kind() const override {
793 return "exception";
794 }
795
callExceptionValue796 std::shared_ptr<SugaredValue> call(
797 const SourceRange& loc,
798 GraphFunction& m,
799 at::ArrayRef<NamedValue> args,
800 at::ArrayRef<NamedValue> /*attributes*/,
801 size_t /*n_binders*/) override {
802 auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc);
803 for (auto& input : args) {
804 auto input_str = input.value(*m.graph());
805 if (!input_str->type()->isSubtypeOf(*StringType::get())) {
806 input_str =
807 emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {});
808 }
809 exception_message = emitBuiltinCall(
810 loc, *m.graph(), aten::add, {exception_message, input_str}, {});
811 }
812 return std::make_shared<ExceptionMessageValue>(exception_message);
813 }
814
815 std::string message_;
816 };
817
818 struct TORCH_API SugaredEnumClass : public SugaredValue {
SugaredEnumClassSugaredEnumClass819 explicit SugaredEnumClass(EnumTypePtr enum_type)
820 : enum_type_(std::move(enum_type)) {}
821
kindSugaredEnumClass822 std::string kind() const override {
823 return "EnumClass";
824 }
825
826 SugaredValuePtr attr(
827 const SourceRange& loc,
828 GraphFunction& m,
829 const std::string& field) override;
830
831 SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
832
833 private:
834 EnumTypePtr enum_type_;
835 };
836
837 struct TORCH_API SliceValue : public SugaredValue {
SliceValueSliceValue838 explicit SliceValue(Value* start, Value* stop, Value* step)
839 : start_(start), stop_(stop), step_(step) {}
840
kindSliceValue841 std::string kind() const override {
842 return "Python slice value";
843 }
844
startSliceValue845 Value* start() {
846 return start_;
847 };
stopSliceValue848 Value* stop() {
849 return stop_;
850 };
stepSliceValue851 Value* step() {
852 return step_;
853 };
854
855 private:
856 Value* start_;
857 Value* stop_;
858 Value* step_;
859 };
860
861 } // namespace torch::jit
862