xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/sugared_value.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <memory>
3 #include <optional>
4 #include <string>
5 #include <utility>
6 
7 #include <ATen/core/symbol.h>
8 #include <caffe2/serialize/versions.h>
9 #include <torch/csrc/jit/api/module.h>
10 #include <torch/csrc/jit/frontend/error_report.h>
11 #include <torch/csrc/jit/frontend/schema_matching.h>
12 #include <torch/csrc/jit/frontend/versioned_symbols.h>
13 #include <torch/csrc/jit/ir/ir.h>
14 
15 namespace torch::jit {
16 
// Forward declaration so the alias below does not depend on some other
// header having already declared SugaredValue; the full definition follows.
struct SugaredValue;
// Shared-ownership handle for sugared values, used throughout this header.
using SugaredValuePtr = std::shared_ptr<SugaredValue>;
18 
19 // The AST can contain nodes like `self`, `self.b` or `python_fn` that
20 // are not first-class values in the graph representation, but instead
21 // will be desugared based on how they are used in the AST.
22 
23 // SugaredValue is used to temporarily represent these values in a way
24 // that separates their behavior from the AST -> IR converter itself.
25 // This allows us to keep dependencies on python minimal.
26 
27 struct TORCH_API SugaredValue
28     : public std::enable_shared_from_this<SugaredValue> {
29   // what is this node? for error reporting (e.g. Module, python function)
30   virtual std::string kind() const = 0;
31 
32   // what can we do with this thing?
33   // use it as a value e.g.  `this + 4`
asValueSugaredValue34   virtual Value* asValue(const SourceRange& loc, GraphFunction& m) {
35     throw(ErrorReport(loc) << kind() << " cannot be used as a value");
36   }
37 
38   // select an attribute on it, e.g. `this.field`
attrSugaredValue39   virtual std::shared_ptr<SugaredValue> attr(
40       const SourceRange& loc,
41       GraphFunction& m,
42       const std::string& field) {
43     throw(ErrorReport(loc) << "attribute lookup is not defined on " << kind());
44   }
45 
hasAttrSugaredValue46   virtual bool hasAttr(
47       const SourceRange& loc,
48       GraphFunction& m,
49       const std::string& field) {
50     throw(ErrorReport(loc) << "attribute lookup is not defined on " << kind());
51   }
52 
53   // assign an attribute on it, e.g. `this.field = newValue`
setAttrSugaredValue54   virtual void setAttr(
55       const SourceRange& loc,
56       GraphFunction& m,
57       const std::string& field,
58       Value* newValue) {
59     throw(
60         ErrorReport(loc) << "attribute assignment is not defined on "
61                          << kind());
62   }
63 
64   // use it as a vector of values, e.g. a tuple of values as return value from
65   // a method invocation
66   virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
67       const SourceRange& loc,
68       GraphFunction& m,
69       const std::optional<size_t>& size_hint = {}) {
70     throw(ErrorReport(loc) << kind() << " cannot be used as a tuple");
71   }
72 
73   // TODO @wconstab refactor to use ModuleValue::asTuple instead of new API
asTupleValueSugaredValue74   virtual SugaredValuePtr asTupleValue(
75       const SourceRange& loc,
76       GraphFunction& m) {
77     throw(ErrorReport(loc) << kind() << " cannot be used as a tuplevalue");
78   }
79 
asTypeSugaredValue80   virtual std::vector<std::shared_ptr<SugaredValue>> asType(
81       const SourceRange& loc,
82       Method& m) {
83     throw(ErrorReport(loc) << kind() << " cannot be used as a type");
84   }
85 
86   // call it like a function, e.g. `outputs = this(inputs)`
callSugaredValue87   virtual std::shared_ptr<SugaredValue> call(
88       const SourceRange& loc,
89       GraphFunction& m,
90       // note: names for args will be 'argument 0', 'argument 1', etc..
91       at::ArrayRef<NamedValue> args,
92       at::ArrayRef<NamedValue> kwargs,
93       size_t n_binders) {
94     // n_binders is always set to the number of variables an expression is
95     // syntactically bound to:
96     //     a = foo() # 1 binder (note in this case the single binder might be a
97     //     tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0
98     //     binders
99     //
100     // In subexpressions, like bar() in foo(bar()), n_binders is always set to
101     // 1. n_binders is used as a hint to subexpressions to determine how many
102     // values they should return when that number is ambiguous statically. In
103     // particular it is currently used to decide how many tensors a call to a
104     // python function will return. It is only a hint, functions do not have to
105     // check that n_binders match the number of things they are returning, the
106     // assignment logic will do that anyway.
107 
108     throw(ErrorReport(loc) << "cannot call a " << kind());
109   }
110 
111   // This function is called when to convert a SugaredValue to its iterator.
112   // For example, when iterating through a Dict we iterate over its keys
iterSugaredValue113   virtual std::shared_ptr<SugaredValue> iter(
114       const SourceRange& loc,
115       GraphFunction& m) {
116     throw(ErrorReport(loc) << kind() << " cannot be used as an iterable");
117   }
118 
119   // If we are iterating over a Sugared Value and it returns a value from this
120   // function, then we emit an unrolled loop over the variable. This allows us
121   // to support containers of Heterogenous types, like Module Containers &
122   // Tuples
staticLenSugaredValue123   virtual std::optional<int64_t> staticLen() {
124     return std::nullopt;
125   }
126 
127   // When iterating over this SugaredValue, should we emit the for loop as an
128   // unrolled loop.
shouldEmitUnrolledSugaredValue129   bool shouldEmitUnrolled() {
130     return staticLen() != std::nullopt;
131   }
132 
133   // return length of this thing, if not then it can't be iterated.
134   // If it does not have a statically-determinable length, then it cannot
135   // be iterated over with a modulelist. If it does it must return a constant
136   // Value *
lenSugaredValue137   virtual Value* len(const SourceRange& loc, GraphFunction& m) {
138     throw(
139         ErrorReport(loc) << "'" << kind() << "'"
140                          << " object is not iterable");
141   }
142 
143   // expression for ith elemement for iterable value
144   virtual std::shared_ptr<SugaredValue> getitem(
145       const SourceRange& loc,
146       GraphFunction& m,
147       Value* idx,
148       TypePtr type_hint = nullptr) {
149     throw(
150         ErrorReport(loc) << "'" << kind() << "'"
151                          << " object is not subscriptable");
152   }
153 
154   virtual ~SugaredValue() = default;
155 };
156 
157 // most things in the environment are just simple value types
158 // and not special python syntax sugar types
159 struct TORCH_API SimpleValue : public SugaredValue {
SimpleValueSimpleValue160   SimpleValue(Value* value) : value_(value) {}
kindSimpleValue161   std::string kind() const override {
162     std::stringstream ss;
163     // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
164     ss << "value of type '" << value_->type()->annotation_str() << "'";
165     return ss.str();
166   }
asValueSimpleValue167   Value* asValue(const SourceRange& range, GraphFunction& m) override {
168     return value_;
169   }
170   std::vector<std::shared_ptr<SugaredValue>> asTuple(
171       const SourceRange& loc,
172       GraphFunction& m,
173       const std::optional<size_t>& size_hint = {}) override;
174   std::shared_ptr<SugaredValue> attr(
175       const SourceRange& loc,
176       GraphFunction& m,
177       const std::string& field) override;
178 
179   bool hasAttr(
180       const SourceRange& loc,
181       GraphFunction& m,
182       const std::string& field) override;
183 
184   void setAttr(
185       const SourceRange& loc,
186       GraphFunction& m,
187       const std::string& field,
188       Value* newValue) override;
189 
190   std::shared_ptr<SugaredValue> call(
191       const SourceRange& loc,
192       GraphFunction& m,
193       // note: names for args will be 'argument 0', 'argument 1', etc..
194       at::ArrayRef<NamedValue> args,
195       at::ArrayRef<NamedValue> kwargs,
196       size_t n_binders) override;
197 
198   std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
199       override;
200 
getValueSimpleValue201   Value* getValue() const {
202     return value_;
203   }
204 
205   Value* len(const SourceRange& loc, GraphFunction& m) override;
206   SugaredValuePtr getitem(
207       const SourceRange& loc,
208       GraphFunction& m,
209       Value* idx,
210       TypePtr type_hint = nullptr) override;
211 
212  private:
213   Value* value_;
214 };
215 
216 struct TORCH_API BuiltinFunction : public SugaredValue {
BuiltinFunctionBuiltinFunction217   BuiltinFunction(Symbol symbol, std::optional<NamedValue> self)
218       : symbol(symbol), self(std::move(self)) {}
219 
220   // The symbol of the function (e.g. `aten::relu`).
221   Symbol symbol;
222 
223   // if this is method, then this is the self argument.
224   std::optional<NamedValue> self;
kindBuiltinFunction225   std::string kind() const override {
226     return "builtin";
227   }
228   std::shared_ptr<SugaredValue> call(
229       const SourceRange& loc,
230       GraphFunction& m,
231       at::ArrayRef<NamedValue> args,
232       at::ArrayRef<NamedValue> kwargs,
233       size_t n_binders) override;
234 
235   // try to create this builtin but if it doesn't exist or the self argument
236   // cannot possibly match, then return nullptr. Use in situations where it is
237   // not clear if it is a valid builtin
238   static std::shared_ptr<BuiltinFunction> tryCreate(
239       Symbol symbol,
240       std::optional<NamedValue> self);
241 };
242 
243 struct TORCH_API SugaredTupleValue : public SugaredValue {
SugaredTupleValueSugaredTupleValue244   explicit SugaredTupleValue(std::vector<std::shared_ptr<SugaredValue>> tup)
245       : tup_(std::move(tup)){};
246 
247   std::vector<std::shared_ptr<SugaredValue>> asTuple(
248       const SourceRange& loc,
249       GraphFunction& m,
250       const std::optional<size_t>& size_hint = {}) override {
251     return tup_;
252   };
253 
asValueSugaredTupleValue254   Value* asValue(const SourceRange& loc, GraphFunction& m) override {
255     std::vector<Value*> vec;
256     vec.reserve(tup_.size());
257     for (const auto& sv : tup_) {
258       vec.push_back(sv->asValue(loc, m));
259     }
260     Graph& g = *m.graph();
261     return g.insertNode(g.createTuple(vec))->output();
262   }
263 
kindSugaredTupleValue264   std::string kind() const override {
265     return "Tuple";
266   }
267 
268   SugaredValuePtr getitem(
269       const SourceRange& loc,
270       GraphFunction& m,
271       Value* idx,
272       TypePtr type_hint = nullptr) override {
273     if (!(idx->type()->cast<IntType>() && toIValue(idx))) {
274       throw(
275           ErrorReport(loc)
276           << "Expected integer literal for index but got a variable or non-integer. "
277           << "ModuleList/Sequential indexing is only supported with integer literals. "
278           << "For example, 'i = 4; self.layers[i](x)' will fail because i is not a literal. "
279           << "Enumeration is supported, e.g. 'for index, v in enumerate(self): out = v(inp)'");
280     }
281     auto index = toIValue(idx)->toInt();
282     int64_t adj_index =
283         (index < 0) ? index + static_cast<int64_t>(tup_.size()) : index;
284     if (!(adj_index >= 0 && adj_index < static_cast<int64_t>(tup_.size()))) {
285       throw(
286           ErrorReport(loc) << "Index " << index << " out of range of length "
287                            << tup_.size());
288     }
289     return tup_.at(adj_index);
290   }
291 
292   // This function is called when a SugaredValue is used to convert a
293   // SugaredValue to its iterator. For example, when iterating through a Dict we
294   // iterate over its keys
iterSugaredTupleValue295   std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
296       override {
297     return shared_from_this();
298   };
299 
300   // Because this is used to contain SugaredValues of Heterogenous types,
301   // we define staticLen() so that when this is iterated over it is emitted
302   // as an unrolled loop.
staticLenSugaredTupleValue303   std::optional<int64_t> staticLen() override {
304     return static_cast<int64_t>(tup_.size());
305   }
306 
307   std::vector<std::shared_ptr<SugaredValue>> tup_;
308 };
309 
310 struct TORCH_API BuiltinModule : public SugaredValue {
311   BuiltinModule(std::string name, std::optional<int64_t> version = std::nullopt)
nameBuiltinModule312       : name(std::move(name)), version(version) {}
313 
kindBuiltinModule314   std::string kind() const override {
315     return "builtin module";
316   }
attrBuiltinModule317   std::shared_ptr<SugaredValue> attr(
318       const SourceRange& loc,
319       GraphFunction& m,
320       const std::string& field) override {
321     if (field == "autograd") {
322       // When refering torch.autograd, it is also considered to be a
323       // BuiltinModule and we will dispatch to the aten operators for the
324       // methods under its module.
325       return std::make_shared<BuiltinModule>("aten", version);
326     }
327 
328     auto sym = Symbol::fromQualString(name + "::" + field);
329     return std::make_shared<BuiltinFunction>(sym, std::nullopt);
330   }
331 
332  private:
333   std::string name;
334   // when we add operator versioning, emit this op as it exising at 'version'
335   // if not set, use the latest version
336   std::optional<int64_t> version;
337 };
338 
339 // Represents a class, analagous to `int` or `dict`. Instances of classes,
340 // like `1` or `{"foo": 5}`, are represented as SimpleValues
341 struct TORCH_API ClassValue : public SugaredValue {
ClassValueClassValue342   explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {}
343 
344   // Call the type's constructor, as in:
345   //    n = Foo(constructor_arg)
346   std::shared_ptr<SugaredValue> call(
347       const SourceRange& loc,
348       GraphFunction& m,
349       at::ArrayRef<NamedValue> args,
350       at::ArrayRef<NamedValue> kwargs,
351       size_t n_binders) override;
352 
353   std::shared_ptr<SugaredValue> attr(
354       const SourceRange& loc,
355       GraphFunction& m,
356       const std::string& field) override;
357 
kindClassValue358   std::string kind() const override {
359     return type_->str();
360   }
361 
362   ClassTypePtr type_;
363 };
364 
365 struct TORCH_API NamedTupleConstructor : public SugaredValue {
NamedTupleConstructorNamedTupleConstructor366   explicit NamedTupleConstructor(TupleTypePtr type) : type_(std::move(type)) {}
367 
368   std::shared_ptr<SugaredValue> call(
369       const SourceRange& loc,
370       GraphFunction& m,
371       at::ArrayRef<NamedValue> args,
372       at::ArrayRef<NamedValue> kwargs,
373       size_t n_binders) override;
374 
kindNamedTupleConstructor375   std::string kind() const override {
376     return type_->str();
377   }
378 
379   TupleTypePtr type_;
380 };
381 
382 struct FunctionValue : public SugaredValue {
FunctionValueFunctionValue383   FunctionValue(Function* callee) : callees_({callee}) {}
FunctionValueFunctionValue384   FunctionValue(const StrongFunctionPtr& p)
385       : callees_({p.function_}), cu_(p.cu_) {}
FunctionValueFunctionValue386   FunctionValue(const std::vector<StrongFunctionPtr>& callees) {
387     for (const StrongFunctionPtr& callee : callees) {
388       cu_ = cu_ ? cu_ : callee.cu_;
389       TORCH_INTERNAL_ASSERT(callee.cu_ == cu_);
390       callees_.push_back(callee.function_);
391     }
392   }
393 
kindFunctionValue394   std::string kind() const override {
395     return "function";
396   }
397 
callFunctionValue398   std::shared_ptr<SugaredValue> call(
399       const SourceRange& loc,
400       GraphFunction& f,
401       at::ArrayRef<NamedValue> args,
402       at::ArrayRef<NamedValue> kwargs,
403       size_t n_binders) override {
404     std::vector<const FunctionSchema*> schemas;
405     for (Function* callee : callees_) {
406       try {
407         callee->ensure_defined();
408       } catch (const RecursiveMethodCallError&) {
409         throw(
410             ErrorReport(loc)
411             << " function '" << callee->name() << "' is called recursively. "
412             << "Recursive calls are not supported");
413       }
414       schemas.push_back(&callee->getSchema());
415     }
416     auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs);
417     Value* output =
418         f.graph()->insertFunctionCall(callees_[match.first], match.second);
419     output->node()->setSourceRange(loc);
420     return std::make_shared<SimpleValue>(output);
421   }
422 
calleesFunctionValue423   const std::vector<Function*>& callees() {
424     return callees_;
425   }
426 
427  private:
428   std::vector<Function*> callees_;
429   // TODO holding this thing is creepy
430   std::shared_ptr<CompilationUnit> cu_;
431 };
432 
433 struct TORCH_API ClosureValue : public SugaredValue {
ClosureValueClosureValue434   ClosureValue(Value* value) : value_(value) {
435     TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure);
436   }
kindClosureValue437   std::string kind() const override {
438     return "closure";
439   }
asValueClosureValue440   Value* asValue(const SourceRange& range, GraphFunction& m) override {
441     return value_;
442   }
443   Value* value_;
444 };
445 
446 // defines how a method obtained from a module/class/interface behaves in script
447 struct MethodValue : public SugaredValue {
MethodValueMethodValue448   MethodValue(Value* self, std::vector<std::string> method_names)
449       : self_(self), method_names_(std::move(method_names)) {}
MethodValueMethodValue450   MethodValue(Value* self, std::string method_name)
451       : MethodValue(self, std::vector<std::string>({std::move(method_name)})) {}
452 
kindMethodValue453   std::string kind() const override {
454     return "method";
455   }
456 
callMethodValue457   std::shared_ptr<SugaredValue> call(
458       const SourceRange& loc,
459       GraphFunction& f,
460       at::ArrayRef<NamedValue> args,
461       at::ArrayRef<NamedValue> kwargs,
462       size_t n_binders) override {
463     std::vector<NamedValue> argsWithSelf = {self_};
464     argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end());
465     std::vector<const FunctionSchema*> schemas;
466     for (const std::string& method_name : method_names_) {
467       if (auto class_type = self_->type()->cast<ClassType>()) {
468         Function& method = class_type->getMethod(method_name);
469         try {
470           method.ensure_defined();
471         } catch (const RecursiveMethodCallError&) {
472           throw(
473               ErrorReport(loc)
474               << " method '" << method.name() << "' is called recursively. "
475               << "Recursive calls are not supported");
476         }
477         schemas.push_back(&method.getSchema());
478       } else if (auto interface_type = self_->type()->cast<InterfaceType>()) {
479         schemas.push_back(interface_type->getMethod(method_name));
480       } else {
481         TORCH_INTERNAL_ASSERT(
482             false, "method constructed that is not a class or interface");
483       }
484     }
485     auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs);
486     Value* output =
487         f.graph()->insertMethodCall(method_names_[match.first], match.second);
488     output->node()->setSourceRange(loc);
489     return std::make_shared<SimpleValue>(output);
490   }
491 
492  private:
493   Value* self_;
494   std::vector<std::string> method_names_;
495 };
496 
497 struct TORCH_API PrintValue : public SugaredValue {
kindPrintValue498   std::string kind() const override {
499     return "print";
500   }
501   std::shared_ptr<SugaredValue> call(
502       const SourceRange& loc,
503       GraphFunction& m,
504       at::ArrayRef<NamedValue> args,
505       at::ArrayRef<NamedValue> kwargs,
506       size_t n_binders) override;
507 };
508 
509 // expressions like int(x)
510 // these are the same as call prim::Int or equivalent except it
511 // is a noop when the input is a subtype of 'type'
512 struct TORCH_API CastValue : public BuiltinFunction {
CastValueCastValue513   CastValue(TypePtr type, c10::Symbol method)
514       : BuiltinFunction(method, std::nullopt), type_(std::move(type)) {}
callCastValue515   std::shared_ptr<SugaredValue> call(
516       const SourceRange& loc,
517       GraphFunction& m,
518       at::ArrayRef<NamedValue> args,
519       at::ArrayRef<NamedValue> kwargs,
520       size_t n_binders) override {
521     if (args.size() == 1 && kwargs.empty()) {
522       auto len_op = std::make_shared<BuiltinFunction>(aten::len, std::nullopt);
523       auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, std::nullopt);
524       auto zero = m.graph()->insertConstant(0);
525 
526       auto v = args[0].value(*m.graph());
527       if (v->type()->isSubtypeOf(*type_)) {
528         return std::make_shared<SimpleValue>(v);
529       } else if (
530           *type_ == *BoolType::get() &&
531           (v->type()->isSubtypeOf(*AnyListType::get()) ||
532            v->type()->isSubtypeOf(*StringType::get()) ||
533            v->type()->cast<DictType>())) {
534         auto len = len_op->call(loc, m, {v}, {}, 1);
535         return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1);
536       }
537     }
538     return BuiltinFunction::call(loc, m, args, kwargs, n_binders);
539   }
540 
541  private:
542   TypePtr type_;
543 };
544 
545 struct TORCH_API TensorCastValue : public SugaredValue {
TensorCastValueTensorCastValue546   TensorCastValue(at::ScalarType type, NamedValue self)
547       : dtype_(type), self_(std::move(self)) {}
548 
kindTensorCastValue549   std::string kind() const override {
550     return "Cast";
551   }
552 
callTensorCastValue553   std::shared_ptr<SugaredValue> call(
554       const SourceRange& loc,
555       GraphFunction& m,
556       at::ArrayRef<NamedValue> args,
557       at::ArrayRef<NamedValue> kwargs,
558       size_t n_binders) override {
559     TORCH_INTERNAL_ASSERT(args.empty() && kwargs.empty());
560     Value* dtype_const = m.graph()->insertConstant(dtype_, loc);
561     std::vector<NamedValue> kwargs_{
562         self_, NamedValue(loc, "dtype", dtype_const)};
563     Value* casted_val = m.graph()->insert(
564         /*opname=*/Symbol::fromQualString("aten::to"),
565         /*args=*/args,
566         /*kwargs=*/kwargs_,
567         /*range=*/loc);
568     return std::make_shared<SimpleValue>(casted_val);
569   }
570 
571   at::ScalarType dtype_;
572   NamedValue self_;
573 };
574 
575 // builtins operators and functions that call a method if it exists
576 // on a class type, like 'len(x)' and 'x + y'
577 struct TORCH_API MagicMethod : public SugaredValue {
MagicMethodMagicMethod578   MagicMethod(std::string desugared_name, SugaredValuePtr base)
579       : base_value_(std::move(base)),
580         desugared_name_(std::move(desugared_name)) {}
581 
kindMagicMethod582   std::string kind() const override {
583     return desugared_name_;
584   }
585 
586   std::shared_ptr<SugaredValue> call(
587       const SourceRange& loc,
588       GraphFunction& m,
589       at::ArrayRef<NamedValue> args,
590       at::ArrayRef<NamedValue> kwargs,
591       size_t n_binders) override;
592 
593  private:
594   SugaredValuePtr base_value_;
595   std::string desugared_name_;
596 };
597 
598 // things that look like function applications, but
599 // perform non-standard evaluation are represented
600 // with SpecialFormValues, e.g.
601 //   isinstance(x, int)
602 //   fork(fn)
603 //   annotate(int, 3)
604 // The implementation of each value is handled by a case inside emitApplyExpr
605 struct TORCH_API SpecialFormValue : public SugaredValue {
SpecialFormValueSpecialFormValue606   SpecialFormValue(Symbol form) : form_(form) {}
kindSpecialFormValue607   std::string kind() const override {
608     return form_.toUnqualString();
609   }
formSpecialFormValue610   Symbol form() const {
611     return form_;
612   }
createSpecialFormValue613   static std::shared_ptr<SpecialFormValue> create(Symbol form) {
614     return std::make_shared<SpecialFormValue>(form);
615   }
616 
617  private:
618   Symbol form_;
619 };
620 
621 struct TORCH_API LegacyTensorConstructor : public SpecialFormValue {
LegacyTensorConstructorLegacyTensorConstructor622   LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device)
623       : SpecialFormValue(form), device_(device), dtype_(dtype) {}
624 
createLegacyTensorConstructor625   static std::shared_ptr<LegacyTensorConstructor> create(
626       Symbol form,
627       at::ScalarType dtype,
628       at::Device device) {
629     return std::make_shared<LegacyTensorConstructor>(form, dtype, device);
630   }
dtypeLegacyTensorConstructor631   at::ScalarType dtype() const {
632     return dtype_;
633   }
634 
635  private:
636   at::Device device_;
637   at::ScalarType dtype_;
638 };
639 
640 // matched against for special handling of range expressions
641 struct TORCH_API RangeValue : SugaredValue {
642   RangeValue(
643       const SourceRange& loc,
644       GraphFunction& m,
645       std::vector<Value*> input,
646       std::optional<int64_t> static_len = std::nullopt);
647 
kindRangeValue648   std::string kind() const override {
649     return "range";
650   }
651   Value* len(const SourceRange& loc, GraphFunction& m) override;
652   SugaredValuePtr getitem(
653       const SourceRange& loc,
654       GraphFunction& m,
655       Value* idx,
656       TypePtr type_hint = nullptr) override;
657   std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
658       override;
659 
660   // When Range is instantiated via enumerate(iterable_with_static_len),
661   // then it takes the static length of the iterable
staticLenRangeValue662   std::optional<int64_t> staticLen() override {
663     return static_len_;
664   }
665 
666  private:
667   Value* start_{};
668   Value* end_{};
669   Value* step_{};
670   // a flag to determine if it's a simple range() call with only end_ from
671   // arguments If true, we will not insert length calculation and index
672   // derivation nodes to simplify the graph and enable more possible
673   // optimizations
674   bool has_only_end_{};
675   std::optional<int64_t> static_len_;
676 };
677 
678 // Specialized Tree structure to matched against for special handling
679 // of builtin functions iterables expressions like zip(), enumerate(), etc.
680 // zip and enumerate can be modeled as a tree of SimpleValue/RangeValue:
681 //    zip(x, y) ->  (x, y) with tuple assignment to each loop target
682 //    enumerate(x) -> (range(0, math.inf, 1), x)
683 // So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be:
684 // (a, (range(0, math.inf, 1), b), range(0, 100))
685 // We use those base iterables to fill in the loop information like
686 // max_trip_count and set the value table for loop targets
687 // Iterables can contain lists of SugaredValues like ModuleLists. If it
688 // does, then we emit it unrolled and require that all values it contains
689 // have a statically-determinable length.
690 struct TORCH_API IterableTree : SugaredValue {
691   IterableTree() = default;
IterableTreeIterableTree692   IterableTree(
693       const SourceRange& range,
694       GraphFunction& m,
695       at::ArrayRef<SugaredValuePtr> children) {
696     for (const auto& child : children) {
697       addChild(range, m, child);
698     }
699   }
kindIterableTree700   std::string kind() const override {
701     return "iterabletree";
702   }
703 
iterIterableTree704   std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
705       override {
706     return shared_from_this();
707   }
708 
709   void addChild(
710       const SourceRange& range,
711       GraphFunction& m,
712       const SugaredValuePtr& iter_value);
713 
get_childrenIterableTree714   std::vector<SugaredValuePtr> get_children() {
715     return children_;
716   }
717 
718   // If this iterable contains a ModuleList or Tuple, then it will have a
719   // static length, and we will emit it as an unrolled for loop.
staticLenIterableTree720   std::optional<int64_t> staticLen() override {
721     return unroll_length_;
722   }
723 
724   // given a IterableTree node, get all the base iterables/leaves under the
725   // IterableTree node. This enables
726   // us to get all the basic SugaredValues that contains valid loop information
727   // with len() and getitem()
728   std::vector<SugaredValuePtr> get_base_iterables();
729 
730   Value* len(const SourceRange& loc, GraphFunction& m) override;
731   SugaredValuePtr getitem(
732       const SourceRange& loc,
733       GraphFunction& m,
734       Value* idx,
735       TypePtr type_hint = nullptr) override;
736 
737  private:
738   std::optional<int64_t> unroll_length_ = std::nullopt;
739   std::vector<SugaredValuePtr> children_;
740 };
741 
toValues(Graph & g,at::ArrayRef<NamedValue> nvs)742 static inline std::vector<Value*> toValues(
743     Graph& g,
744     at::ArrayRef<NamedValue> nvs) {
745   return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
746 }
747 
748 struct SimpleSelf : public Self {
SimpleSelfSimpleSelf749   explicit SimpleSelf(ClassTypePtr classType)
750       : Self(), classType_(std::move(classType)) {}
makeSugaredSimpleSelf751   std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
752     v->setType(classType_);
753     return std::make_shared<SimpleValue>(v);
754   }
getClassTypeSimpleSelf755   ClassTypePtr getClassType() const override {
756     return classType_;
757   }
758 
759  private:
760   ClassTypePtr classType_;
761 };
762 
763 // This is not a SimpleValue so it can not pass through the code paths that
764 // expect a SimpleValue as a sugared value.
765 struct TORCH_API ExceptionMessageValue : public SugaredValue {
766   explicit ExceptionMessageValue(
767       Value* value,
768       Value* qualified_class_name = nullptr)
value_ExceptionMessageValue769       : value_(value), qualified_class_name_(qualified_class_name) {}
770 
kindExceptionMessageValue771   std::string kind() const override {
772     return "exception message";
773   }
774 
getValueExceptionMessageValue775   Value* getValue() {
776     return value_;
777   }
778 
779   // qualified python class name
getQualifiedClassNameExceptionMessageValue780   Value* getQualifiedClassName() {
781     return qualified_class_name_;
782   }
783 
784  private:
785   Value* value_;
786   Value* qualified_class_name_;
787 };
788 
789 struct TORCH_API ExceptionValue : public SugaredValue {
ExceptionValueExceptionValue790   explicit ExceptionValue(std::string message) : message_(std::move(message)) {}
791 
kindExceptionValue792   std::string kind() const override {
793     return "exception";
794   }
795 
callExceptionValue796   std::shared_ptr<SugaredValue> call(
797       const SourceRange& loc,
798       GraphFunction& m,
799       at::ArrayRef<NamedValue> args,
800       at::ArrayRef<NamedValue> /*attributes*/,
801       size_t /*n_binders*/) override {
802     auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc);
803     for (auto& input : args) {
804       auto input_str = input.value(*m.graph());
805       if (!input_str->type()->isSubtypeOf(*StringType::get())) {
806         input_str =
807             emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {});
808       }
809       exception_message = emitBuiltinCall(
810           loc, *m.graph(), aten::add, {exception_message, input_str}, {});
811     }
812     return std::make_shared<ExceptionMessageValue>(exception_message);
813   }
814 
815   std::string message_;
816 };
817 
818 struct TORCH_API SugaredEnumClass : public SugaredValue {
SugaredEnumClassSugaredEnumClass819   explicit SugaredEnumClass(EnumTypePtr enum_type)
820       : enum_type_(std::move(enum_type)) {}
821 
kindSugaredEnumClass822   std::string kind() const override {
823     return "EnumClass";
824   }
825 
826   SugaredValuePtr attr(
827       const SourceRange& loc,
828       GraphFunction& m,
829       const std::string& field) override;
830 
831   SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
832 
833  private:
834   EnumTypePtr enum_type_;
835 };
836 
837 struct TORCH_API SliceValue : public SugaredValue {
SliceValueSliceValue838   explicit SliceValue(Value* start, Value* stop, Value* step)
839       : start_(start), stop_(stop), step_(step) {}
840 
kindSliceValue841   std::string kind() const override {
842     return "Python slice value";
843   }
844 
startSliceValue845   Value* start() {
846     return start_;
847   };
stopSliceValue848   Value* stop() {
849     return stop_;
850   };
stepSliceValue851   Value* step() {
852     return step_;
853   };
854 
855  private:
856   Value* start_;
857   Value* stop_;
858   Value* step_;
859 };
860 
861 } // namespace torch::jit
862