// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/ir_emitter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tree_views.h>

#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/annotate_warns.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inline_forked_closures.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lift_closures.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <torch/csrc/jit/testing/hooks_for_testing.h>

#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>

#include <climits>
#include <cstdlib> // std::getenv (added: previously relied on a transitive include)
#include <cstring> // std::strcmp (added: previously relied on a transitive include)
#include <optional>
#include <set>
#include <stack>
44 
namespace {
// Decides whether full source-location debug info should be recorded for a
// compiled file of `file_size` bytes.
//
// Files under 512 KiB always get source locations. For larger files the
// (potentially expensive) bookkeeping is opt-in via the environment variable
// PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION: unset, "0", "FALSE", or "false"
// mean disabled; any other value enables it.
bool reportSourceLocation(size_t file_size) {
  constexpr size_t kLargeFileThreshold = 512ull * 1024; // 512 KiB
  if (file_size < kLargeFileThreshold) {
    return true;
  }
  const char* enable_env =
      std::getenv("PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION");
  if (enable_env == nullptr || std::strcmp(enable_env, "0") == 0 ||
      std::strcmp(enable_env, "FALSE") == 0 ||
      std::strcmp(enable_env, "false") == 0) {
    return false;
  }
  return true;
}
} // namespace
61 
62 namespace torch::jit {
63 
64 using FunctionTable = std::unordered_map<std::string, Function&>;
65 using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
66 using TypeTable = std::unordered_map<std::string, TypePtr>;
67 using AttributeMap = std::unordered_map<std::string, Const>;
68 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
69 
// A single (variable name, refined type) pair produced by a type test such
// as `x is None` or `isinstance(x, T)`: while the refinement is active, the
// named variable is known to have the given type.
struct Refinement {
  Refinement(std::string identifier, TypePtr type)
      : identifier_(std::move(identifier)), type_(std::move(type)) {}
  // Name of the variable being refined.
  const std::string& identifier() const {
    return identifier_;
  }
  // Type the variable is known to have when this refinement holds.
  TypePtr type() const {
    return type_;
  }

 private:
  std::string identifier_;
  TypePtr type_;
};
84 
85 struct RefinementSet {
86   // When a comparison like x is None is made, we associate type refinements
87   // with its true value and its false value. If a boolean that has refinements
88   // associated with it is used in a conditional of an if statement, the true
89   // and false refinements are inserted into the corresponding blocks
90   using Refinements = std::vector<Refinement>;
91 
RefinementSettorch::jit::RefinementSet92   RefinementSet(Refinements true_refinements, Refinements false_refinements)
93       : true_refinements_(std::move(true_refinements)),
94         false_refinements_(std::move(false_refinements)) {}
RefinementSettorch::jit::RefinementSet95   RefinementSet(Refinement single) : RefinementSet({std::move(single)}, {}) {}
RefinementSettorch::jit::RefinementSet96   RefinementSet(Refinement single_true, Refinement single_false)
97       : RefinementSet(
98             Refinements({std::move(single_true)}),
99             Refinements({std::move(single_false)})) {}
100   RefinementSet() = default; // empty
Andtorch::jit::RefinementSet101   RefinementSet And(const RefinementSet& rhs) const {
102     // if the result of an AND is true, both a & b had to be true,
103     // so we take the union of a.true_refinements and b.true_refinements.
104     // if the result is false, either a or b could have been false,
105     // so we take their intersection.
106     return RefinementSet(
107         unionSet(true_refinements_, rhs.true_refinements_),
108         intersectSet(false_refinements_, rhs.false_refinements_));
109   }
Ortorch::jit::RefinementSet110   RefinementSet Or(const RefinementSet& rhs) const {
111     // if the result of an OR is true, either a & b could have been true,
112     // so we take the intersection of a.true_refinements & b.true_refinements.
113     // if the result is false, both a and b had to be false,
114     // so we take their union.
115     return RefinementSet(
116         intersectSet(true_refinements_, rhs.true_refinements_),
117         unionSet(false_refinements_, rhs.false_refinements_));
118   }
119 
Nottorch::jit::RefinementSet120   RefinementSet Not() const {
121     return RefinementSet(false_refinements_, true_refinements_);
122   }
activeRefinementstorch::jit::RefinementSet123   const std::vector<Refinement> activeRefinements() const {
124     return true_refinements_;
125   }
126 
127  private:
sameVartorch::jit::RefinementSet128   static bool sameVar(const Refinement& a, const Refinement& b) {
129     return a.identifier() == b.identifier();
130   }
unionSettorch::jit::RefinementSet131   static Refinements unionSet(const Refinements& a, const Refinements& b) {
132     Refinements result = a;
133     for (const Refinement& r : b) {
134       auto it =
135           std::find_if(result.begin(), result.end(), [&](const Refinement& e) {
136             return e.identifier() == r.identifier();
137           });
138       if (it == result.end()) {
139         result.push_back(r);
140       } else if (*it->type() != *r.type()) {
141         // we only keep refinements when they exactly match one
142         // refinement type, for instance, we do not attempt to refine:
143         // isinstance(x, float) and isinstance(x, int)
144         result.erase(it);
145       }
146     }
147     return result;
148   }
intersectSettorch::jit::RefinementSet149   static Refinements intersectSet(const Refinements& a, const Refinements& b) {
150     Refinements result;
151     for (const Refinement& r : a) {
152       auto it = std::find_if(b.begin(), b.end(), [&](const Refinement& e) {
153         return e.identifier() == r.identifier();
154       });
155       if (it != b.end() && r.type() == it->type()) {
156         result.push_back(r);
157       }
158     }
159     return result;
160   }
161 
162   Refinements true_refinements_;
163   Refinements false_refinements_;
164 };
165 
166 struct CondValue {
CondValuetorch::jit::CondValue167   CondValue(
168       Value* value,
169       RefinementSet refinements,
170       std::optional<bool> static_if)
171       : value_(value),
172         refinements_(std::move(refinements)),
173         static_if_(static_if) {}
CondValuetorch::jit::CondValue174   CondValue(
175       Graph& g,
176       const SourceRange& loc,
177       bool static_value,
178       RefinementSet refinements)
179       : value_(g.insertConstant(static_value, loc)),
180         refinements_(std::move(refinements)),
181         static_if_(static_value) {}
valuetorch::jit::CondValue182   Value* value() const {
183     return value_;
184   }
refinementstorch::jit::CondValue185   const RefinementSet& refinements() const {
186     return refinements_;
187   }
staticIftorch::jit::CondValue188   std::optional<bool> staticIf() const {
189     return static_if_;
190   }
191 
192  private:
193   Value* value_;
194   RefinementSet refinements_;
195   std::optional<bool>
196       static_if_; // certain expression cause us to emit a static if statement
197                   // this value is present if this is the case.
198                   // this is not equivalent to value_ being a constant
199                   // it is possible for value_ to be constant but for
200                   // the expression that produced it to not trigger the
201                   // static if behavior. e.g. use of a variable assigned
202                   // to a constant
203 };
204 
205 enum NoneStatus { ALWAYS, MAYBE, NEVER };
canBeNone(Value * v)206 static NoneStatus canBeNone(Value* v) {
207   if (v->node()->mustBeNone()) {
208     return ALWAYS;
209   }
210   if (v->type()->kind() == OptionalType::Kind ||
211       (v->type()->kind() == UnionType::Kind &&
212        v->type()->expect<UnionType>()->canHoldType(*NoneType::get()))) {
213     return MAYBE;
214   }
215   return NEVER;
216 }
217 
asSimple(const SugaredValuePtr & value)218 static Value* asSimple(const SugaredValuePtr& value) {
219   if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
220     return sv->getValue();
221   }
222   return nullptr;
223 }
224 
makeMagic(const std::string & name,const SugaredValuePtr & base)225 static std::shared_ptr<MagicMethod> makeMagic(
226     const std::string& name,
227     const SugaredValuePtr& base) {
228   return std::make_shared<MagicMethod>(name, base);
229 }
230 
// Auxiliary data structure for desugaring variable binding into our always
// explicitly scoped language as we descend down nested control structures in
// the frontend (which themselves don't introduce scopes)
//
// The Environment keeps track of two tables, one for values which are not
// first class and a type table for values which are. When a first class value
// is set in the environment, we emit a prim::Store which sets the
// name of the variable to the appropriate type, and when a first-class value
// is referenced we emit a prim::Load that generates a value of the
// appropriate type.
//
// a = 1
// print(a)
// becomes:
// = prim::Store[name="a"](%a.1)
// %a : int = prim::Load[name="a"]()
// prim::Print(%a)
249 struct Environment {
Environmenttorch::jit::Environment250   Environment(
251       GraphFunction& method,
252       ResolverPtr resolver,
253       Block* b,
254       std::shared_ptr<Environment> next = nullptr)
255       : method(method),
256         resolver(std::move(resolver)),
257         b(b),
258         next(std::move(next)) {}
259 
260   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
261   GraphFunction& method;
262   ResolverPtr resolver;
263   std::unordered_map<std::string, std::function<std::string()>> error_messages;
264   Block* b;
265 
266   std::shared_ptr<Environment> next;
267 
268   // set type error in the lowest environment. if the variable is used after an
269   // error has been set, then we will use the more informative error message
setVariableTypeErrortorch::jit::Environment270   void setVariableTypeError(
271       const std::string& name,
272       std::function<std::string()> msg) {
273     auto runner = this;
274     while (runner->next) {
275       runner = runner->next.get();
276     }
277     runner->error_messages[name] = std::move(msg);
278   }
279 
280   // see if type error has been set for a variable
findVariableTypeErrortorch::jit::Environment281   std::optional<std::string> findVariableTypeError(const std::string& name) {
282     auto runner = this;
283     while (runner->next) {
284       runner = runner->next.get();
285     }
286     auto msg = runner->error_messages.find(name);
287     if (msg != runner->error_messages.end()) {
288       return msg->second();
289     } else {
290       return std::nullopt;
291     }
292   }
293 
insertLoadtorch::jit::Environment294   SugaredValuePtr insertLoad(const std::string& name, const TypePtr& type) {
295     auto g = b->owningGraph();
296     auto load = g->insertNode(g->createLoad(name, type));
297     if (meaningfulName(name)) {
298       load->output()->setDebugName(name);
299     }
300     return std::make_shared<SimpleValue>(load->output());
301   }
302 
303   // note: type is not always the same as v->type(), e.g.
304   // type: Optional[Tensor]
305   // v->type(): Tensor
insertStoretorch::jit::Environment306   void insertStore(
307       const std::string& name,
308       const SourceRange& loc,
309       Value* v,
310       TypePtr type) {
311     auto g = b->owningGraph();
312     g->insertNode(g->createStore(name, v))->setSourceRange(loc);
313     type_table[name] = std::move(type);
314   }
315 
findInThisFrametorch::jit::Environment316   SugaredValuePtr findInThisFrame(const std::string& name) {
317     auto it = value_table.find(name);
318     if (it != value_table.end()) {
319       return it->second;
320     }
321     auto it2 = type_table.find(name);
322     if (it2 != type_table.end()) {
323       return insertLoad(name, it2->second);
324     }
325     return nullptr;
326   }
327 
findInParentFrametorch::jit::Environment328   SugaredValuePtr findInParentFrame(const std::string& name) {
329     return next ? next->findInAnyFrame(name) : nullptr;
330   }
331 
setTypetorch::jit::Environment332   void setType(const std::string& name, TypePtr type) {
333     type_table[name] = std::move(type);
334   }
335 
findInAnyFrametorch::jit::Environment336   SugaredValuePtr findInAnyFrame(const std::string& name) {
337     for (auto runner = this; runner; runner = runner->next.get()) {
338       if (auto r = runner->findInThisFrame(name)) {
339         return r;
340       }
341     }
342     return nullptr;
343   }
344 
blocktorch::jit::Environment345   Block* block() {
346     return b;
347   }
348 
setVartorch::jit::Environment349   void setVar(const SourceRange& loc, const std::string& name, Value* value) {
350     setSugaredVar(
351         loc,
352         name,
353         std::make_shared<SimpleValue>(value),
354         /*annotated_type=*/nullptr);
355   }
356 
setSugaredVartorch::jit::Environment357   void setSugaredVar(
358       const SourceRange& loc,
359       const std::string& name,
360       SugaredValuePtr value,
361       const TypePtr& annotated_type) {
362     Value* as_simple_value = asSimple(value);
363     if (as_simple_value && !as_simple_value->hasDebugName() &&
364         meaningfulName(name) &&
365         // note: if the value wasn't defined in this block, we might be giving a
366         // name only used inside this block to a value outside of this. this is
367         // not normally helpful for debugging and causes import/export jitter.
368         as_simple_value->node()->owningBlock() == block()) {
369       as_simple_value->setDebugName(name);
370     }
371     // prevent re-assignment involving any sugared values
372     // any reassignment like:
373     // a = ...
374     // while ...
375     //   a = ..
376     // requires 'a' to be first-class in the graph since its value depends on
377     // control flow
378     if (auto parent = findInParentFrame(name)) {
379       if (annotated_type) {
380         throw(
381             ErrorReport(loc)
382             << "Attempting to declare and annotate the type of variable '"
383             << name << "' but it is already defined in an outer block");
384       }
385       if (!as_simple_value) {
386         throw(
387             ErrorReport(loc)
388             << "Cannot re-assign '" << name << "' to a value of type "
389             << value->kind() << " because " << name
390             << " is not a first-class value.  Only reassignments to first-class values are allowed");
391       }
392       Value* simple_parent = asSimple(parent);
393       if (!simple_parent) {
394         throw(
395             ErrorReport(loc)
396             << "Cannot re-assign '" << name << "' because it has type "
397             << value->kind() << " and " << name
398             << " is not a first-class value.  Only reassignments to first-class values are allowed");
399       }
400 
401       auto parent_type = unshapedType(simple_parent->type());
402       as_simple_value = tryConvertToType(
403           loc,
404           *b->owningGraph(),
405           parent_type,
406           as_simple_value,
407           /*allow_conversions=*/true);
408       std::stringstream why_not;
409       if (!as_simple_value->type()->isSubtypeOfExt(*parent_type, &why_not)) {
410         auto error = ErrorReport(loc);
411         error << "Variable '" << name << "' previously had type "
412               << simple_parent->type()->repr_str()
413               << " but is now being assigned to a value of type "
414               << as_simple_value->type()->repr_str();
415 
416         // Special-cased error msg if we're trying to assign to a tensor list.
417         if (simple_parent->type()->kind() == TypeKind::ListType &&
418             as_simple_value->type()->kind() == TypeKind::ListType) {
419           error << "\nEmpty lists default to List[Tensor]. Add a variable "
420                    "annotation to the assignment to create an empty list "
421                    "of another type (torch.jit.annotate(List[T, []]) where T "
422                    "is the type of elements in the list for Python 2)";
423         }
424         error << "\n" << why_not.str();
425         throw ErrorReport(error);
426       }
427     }
428     if (as_simple_value) {
429       if (annotated_type &&
430           !as_simple_value->type()->isSubtypeOf(*annotated_type)) {
431         throw(
432             ErrorReport(loc)
433             << "Variable '" << name << "' is annotated with type "
434             << annotated_type->repr_str()
435             << " but is being assigned to a value of type "
436             << as_simple_value->type()->repr_str());
437       }
438       auto value_store_type =
439           annotated_type ? annotated_type : as_simple_value->type();
440       insertStore(name, loc, as_simple_value, value_store_type);
441     } else {
442       value_table[name] = std::move(value);
443     }
444   }
445 
getSugaredVartorch::jit::Environment446   SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
447     return getSugaredVar(ident.name(), ident.range());
448   }
getVartorch::jit::Environment449   Value* getVar(const Ident& ident) {
450     return getSugaredVar(ident)->asValue(ident.range(), method);
451   }
452 
throwVarNotFoundErrortorch::jit::Environment453   void throwVarNotFoundError(
454       const std::string& ident,
455       const SourceRange& range) {
456     // check if this value was not emitted in an if statement because of a
457     // type mismatch. if it was, then we print a more informative error msg
458     if (auto msg = findVariableTypeError(ident)) {
459       throw(ErrorReport(range) << *msg << "and was used here");
460     }
461     throw(ErrorReport(range) << "undefined value " << ident);
462   }
463 
getSugaredVartorch::jit::Environment464   SugaredValuePtr getSugaredVar(
465       const std::string& ident,
466       const SourceRange& range,
467       bool required = true) {
468     auto retval = findInAnyFrame(ident);
469 
470     if (!retval) {
471       static std::unordered_map<std::string, SugaredValuePtr> globals = {
472           {"print", std::make_shared<PrintValue>()},
473           {"tuple", SpecialFormValue::create(prim::TupleConstruct)},
474           {"float",
475            makeMagic(
476                "__float__",
477                std::make_shared<CastValue>(FloatType::get(), aten::Float))},
478           {"complex",
479            makeMagic(
480                "__complex__",
481                std::make_shared<CastValue>(ComplexType::get(), aten::Complex))},
482           {"int",
483            makeMagic(
484                "__int__",
485                std::make_shared<CastValue>(IntType::get(), aten::Int))},
486           {"bool",
487            makeMagic(
488                "__bool__",
489                std::make_shared<CastValue>(BoolType::get(), aten::Bool))},
490           {"str",
491            makeMagic(
492                "__str__",
493                std::make_shared<CastValue>(StringType::get(), aten::str))},
494           {"getattr", SpecialFormValue::create(prim::GetAttr)},
495           {"hasattr", SpecialFormValue::create(prim::HasAttr)},
496           {"isinstance", SpecialFormValue::create(prim::isinstance)},
497           // todo(zach): remove when we can correctly export torch.full via ONNX
498           // or we have implicit conversion that can convert numbers to tensors
499           {"_to_tensor",
500            std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
501           {"len",
502            makeMagic(
503                "__len__",
504                std::make_shared<BuiltinFunction>(aten::len, std::nullopt))},
505           {"hex",
506            makeMagic(
507                "__hex__",
508                std::make_shared<BuiltinFunction>(aten::hex, std::nullopt))},
509           {"oct",
510            makeMagic(
511                "__oct__",
512                std::make_shared<BuiltinFunction>(aten::oct, std::nullopt))},
513           {"round",
514            makeMagic(
515                "__round__",
516                std::make_shared<BuiltinFunction>(aten::round, std::nullopt))},
517           {"hash", std::make_shared<BuiltinFunction>(aten::hash, std::nullopt)},
518           {"id", std::make_shared<BuiltinFunction>(prim::id, std::nullopt)},
519           {"min", std::make_shared<BuiltinFunction>(prim::min, std::nullopt)},
520           {"max", std::make_shared<BuiltinFunction>(prim::max, std::nullopt)},
521           {"abs", std::make_shared<BuiltinFunction>(prim::abs, std::nullopt)},
522           {"all", std::make_shared<BuiltinFunction>(aten::all, std::nullopt)},
523           {"any", std::make_shared<BuiltinFunction>(aten::any, std::nullopt)},
524           {"divmod",
525            std::make_shared<BuiltinFunction>(aten::divmod, std::nullopt)},
526           {"sum", std::make_shared<BuiltinFunction>(aten::sum, std::nullopt)},
527           {"list", SpecialFormValue::create(prim::list)},
528           {"dict", SpecialFormValue::create(prim::dict)},
529           {"ord", std::make_shared<BuiltinFunction>(aten::ord, std::nullopt)},
530           {"chr", std::make_shared<BuiltinFunction>(aten::chr, std::nullopt)},
531           {"bin", std::make_shared<BuiltinFunction>(aten::bin, std::nullopt)},
532           {"pow", std::make_shared<BuiltinFunction>(aten::pow, std::nullopt)},
533           {"range", SpecialFormValue::create(prim::range)},
534           {"zip", SpecialFormValue::create(prim::zip)},
535           {"enumerate", SpecialFormValue::create(prim::enumerate)},
536           {"rangelist",
537            std::make_shared<BuiltinFunction>(prim::rangelist, std::nullopt)},
538           {"sorted",
539            std::make_shared<BuiltinFunction>(aten::sorted, std::nullopt)},
540           // Only AssertionError is bound so that we can use it from emitAssert,
541           // all other exceptions should be resolved at the Python level
542           {"AssertionError",
543            std::make_shared<ExceptionValue>("AssertionError")},
544       };
545       auto it = globals.find(ident);
546       if (it != globals.end()) {
547         retval = it->second;
548       }
549     }
550 
551     if (!retval) {
552       if (auto type = resolver->resolveType(ident, range)) {
553         if (auto tuple_type = type->cast<TupleType>()) {
554           retval = std::make_shared<NamedTupleConstructor>(tuple_type);
555         }
556       }
557     }
558 
559     if (!retval) {
560       retval = resolver->resolveValue(ident, method, range);
561     }
562 
563     if (!retval) {
564       if (auto type = resolver->resolveType(ident, range)) {
565         if (auto class_type = type->cast<ClassType>()) {
566           retval = std::make_shared<ClassValue>(class_type);
567         }
568       }
569     }
570 
571     if (!retval && required) {
572       throwVarNotFoundError(ident, range);
573     }
574 
575     return retval;
576   }
577 
getVartorch::jit::Environment578   Value* getVar(const std::string& ident, const SourceRange& range) {
579     return getSugaredVar(ident, range)->asValue(range, method);
580   }
581 
removeVartorch::jit::Environment582   void removeVar(const Ident& ident, bool check_if_removed = false) {
583     bool removed = false;
584 
585     for (auto runner = this; runner; runner = runner->next.get()) {
586       auto a = runner->value_table.erase(ident.name());
587       auto b = runner->type_table.erase(ident.name());
588       removed = a || b;
589     }
590 
591     if (check_if_removed && !removed) {
592       throwVarNotFoundError(ident.name(), ident.range());
593     }
594   }
595 
definedVariablestorch::jit::Environment596   std::vector<std::string> definedVariables() {
597     std::vector<std::string> result;
598     for (auto& kv : type_table) {
599       result.push_back(kv.first);
600     }
601     return result;
602   }
603 
604  private:
605   TypeTable type_table;
606   ValueTable value_table;
607 };
608 
609 template <class T, class Hash>
materializeConstant(T val,Graph & graph,const SourceRange & r,std::unordered_map<T,Value *,Hash> & map)610 static Value* materializeConstant(
611     T val,
612     Graph& graph,
613     const SourceRange& r,
614     std::unordered_map<T, Value*, Hash>& map) {
615   auto existing_constant = map.find(val);
616   if (existing_constant != map.end()) {
617     return existing_constant->second;
618   }
619 
620   WithInsertPoint guard(graph.block()->nodes().front());
621   auto new_constant = graph.insertConstant(val, r);
622   map[val] = new_constant;
623 
624   return new_constant;
625 }
626 
isSupportedListElementType(const TypePtr & type)627 inline bool isSupportedListElementType(const TypePtr& type) {
628   return type->isSubtypeOf(*TensorType::get()) ||
629       type->isSubtypeOf(*NumberType::get());
630 }
631 
// Information for each def being emitted.
// Defs can be nested to support closures so we need a stack of this
// information. Currently records information about the function's return
// type.
struct DefContext {
  TypePtr declared_return_type_; // nullptr if not annotated
  TypePtr merged_return_type_; // nullptr if a Return has not been seen yet
};
639 
// Emission state: whether we are currently inside a loop body, and whether
// that loop is being unrolled.
enum class LoopStatus { NOT_IN_LOOP, IN_LOOP, IN_UNROLLED_LOOP };
641 
642 struct WithLoopStatus {
WithLoopStatustorch::jit::WithLoopStatus643   WithLoopStatus(LoopStatus* prev, LoopStatus new_status)
644       : prev_ptr_(prev), prev_value_(*prev) {
645     *prev = new_status;
646   }
~WithLoopStatustorch::jit::WithLoopStatus647   ~WithLoopStatus() {
648     *prev_ptr_ = prev_value_;
649   }
650 
651  private:
652   LoopStatus* prev_ptr_;
653   LoopStatus prev_value_;
654 };
655 
656 struct to_ir {
  // Emits IR for `def` into `method`'s graph, sets the method schema, and
  // runs the standard post-emission lowering passes in the required order.
  // `self` is non-null when compiling a method (as opposed to a free
  // function).
  to_ir(
      const Def& def,
      ResolverPtr resolver_,
      const Self* self,
      GraphFunction& method) // method being constructed
      : method(method),
        graph(method.graph()),
        resolver(std::move(resolver_)),
        // NB: initializes from the *member* `resolver` (already set, since
        // members initialize in declaration order), not the moved-from param.
        typeParser_(resolver),
        environment_stack(nullptr) {
    AT_ASSERT(resolver);
    pushFrame(graph->block(), /*starts_def=*/true);

    // Type annotations exclude explicitly typing the "self" parameter, so in
    // the case that this is a method with self we expect one fewer parameter
    // annotation than the number of parameters this Def takes.
    if (self && def.decl().params().empty()) {
      throw(
          ErrorReport(def.decl().params().range())
          << "methods must have a self argument");
    }
    method.setSchema(emitDef(def, self, graph->block()));

    // At this point, we might have received a graph that is compiled with
    // old operator schemas that might not exist in the system anymore.
    // Therefore, we replace such ops with its' valid upgrader.
    ReplaceOldOperatorsWithUpgraders(graph);

    // NB ORDERING: SSA conversion has to occur before
    // lifting of closures and forks, this way closures are converted
    // to SSA while part of their original graph, and closures are ready to
    // be inlined into forked closures
    ConvertToSSA(graph);

    // convert loops with an iter and body condition specified to
    // python-recognize while loops. we do this so they can be exported,
    // and run the pass early to avoid jitter. Like conversion to SSA,
    // it only needs to run once.
    CanonicalizeModifiedLoops(graph);

    // Convert Ops to a Normalized Form
    NormalizeOps(graph);

    runCleanupPasses(graph);
  }
702 
703  private:
704   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
705   GraphFunction& method;
706   std::shared_ptr<Graph> graph;
707   ResolverPtr resolver;
708   std::unordered_map<int64_t, Value*, std::hash<int64_t>> integral_constants;
709   std::unordered_map<double, Value*, std::hash<double>> fp_constants;
710   std::unordered_map<
711       c10::complex<double>,
712       Value*,
713       c10::hash<c10::complex<double>>>
714       complex_constants;
715   std::unordered_set<Block*> exit_blocks;
716   ScriptTypeParser typeParser_;
717   LoopStatus loop_status_ = LoopStatus::NOT_IN_LOOP;
718 
719   // Singly-linked list of environments. This top element contains a member
720   // `next` that points to the most immediate enclosing scope's value.
721   std::shared_ptr<Environment> environment_stack;
722   std::vector<DefContext> def_stack_;
723   size_t temp_name_count_ = 0;
createTempNametorch::jit::to_ir724   std::string createTempName(const std::string& prefix) {
725     return prefix + std::to_string(temp_name_count_++);
726   }
727 
  // Pushes a new Environment frame for block `b`; when `starts_def` is true,
  // also opens a new DefContext (used when entering a Def, e.g. a closure).
  void pushFrame(Block* b, bool starts_def = false) {
    if (starts_def) {
      def_stack_.emplace_back();
    }
    // New frame links back to the current one, forming the scope chain.
    environment_stack =
        std::make_shared<Environment>(method, resolver, b, environment_stack);
  }
popFrametorch::jit::to_ir735   std::shared_ptr<Environment> popFrame(bool ends_def = false) {
736     auto old_frame = environment_stack;
737     environment_stack = environment_stack->next;
738     if (ends_def) {
739       def_stack_.pop_back();
740     }
741     return old_frame;
742   }
743 
744   // If the graph might not return, add an implicit None return at the end
handleMaybeNoReturntorch::jit::to_ir745   void handleMaybeNoReturn(const Def& def, Block* block) {
746     auto decl_ret = def_stack_.back().declared_return_type_;
747     if (exit_blocks.count(block) == 0) {
748       auto decl_ret = def_stack_.back().declared_return_type_;
749       if (decl_ret && decl_ret != NoneType::get()) {
750         throw(
751             ErrorReport(def.range())
752             << "Function was not annotated as having type None, but does not "
753             << "return along all paths");
754       }
755       WithInsertPoint b(*block->nodes().end());
756       emitReturn(Return::create(
757           def.range(), Expr(Compound::create(TK_NONE, def.range(), {}))));
758     } else {
759       // if we haven't seen any return statements, but the graph block exits
760       // (the function always throws) then we accept the declared return type if
761       // it exists or set it to none
762       if (def_stack_.back().merged_return_type_ == nullptr) {
763         def_stack_.back().merged_return_type_ =
764             decl_ret != nullptr ? decl_ret : NoneType::get();
765       }
766     }
767   }
768 
emitDeftorch::jit::to_ir769   FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
770     auto schema = typeParser_.parseSchemaFromDef(def, bool(self));
771     // TODO need guards on init returning none
772     if (schema.returns().size() == 1) {
773       def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
774     }
775     std::vector<Argument> arguments =
776         emitFormalArguments(def, self, schema, block);
777 
778     // body
779     auto stmts_list = def.statements();
780     emitStatements(stmts_list.begin(), stmts_list.end());
781     handleMaybeNoReturn(def, block);
782     std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
783     return {def.name().name(), "", std::move(arguments), std::move(returns)};
784   }
785 
786   // see [setstate type]
getTypeForSetStateArgtorch::jit::to_ir787   static TypePtr getTypeForSetStateArg(const Def& def, const Self* self) {
788     TORCH_CHECK(self, "Expected __setstate__ to have a `self` argument");
789     auto getstate = self->getClassType()->findMethod("__getstate__");
790     if (!getstate) {
791       throw(
792           ErrorReport(def.range())
793           << "`__setstate__` defined but not `__getstate__`. "
794           << "You must have both defined on a ScriptModule "
795           << "to customize serialization.\n"
796           << "Did you forget to use `@torch.jit.export`?");
797     }
798     getstate->ensure_defined();
799     return self->getClassType()
800         ->getMethod("__getstate__")
801         .getSchema()
802         .returns()
803         .at(0)
804         .type();
805   }
806 
807   // see [setstate type]
shouldDeriveSetStateTypetorch::jit::to_ir808   static bool shouldDeriveSetStateType(
809       const Def& def,
810       const FunctionSchema& schema) {
811     const bool noTypeAnnotations = std::all_of(
812         schema.arguments().begin(),
813         schema.arguments().end(),
814         [](const Argument& arg) { return arg.is_inferred_type(); });
815 
816     bool shouldInfer = def.name().name() == "__setstate__" && noTypeAnnotations;
817     if (!shouldInfer) {
818       return false;
819     }
820 
821     // Do some additional basic validation that the __setstate__ func is
822     // well-formed
823     TORCH_INTERNAL_ASSERT(def.name().name() == "__setstate__");
824     const auto numDeclParams = def.decl().params().size();
825     if (numDeclParams != 2) {
826       throw(
827           ErrorReport(def.range())
828           << "Expected 2 arguments for `__setstate__`, got: " << numDeclParams);
829     }
830     return true;
831   }
832 
  // Register the formal parameters of `def` as inputs of `block`, bind each
  // one in the current environment, and return the Argument list for the
  // schema. Throws if the number of type annotations does not match the
  // number of (non-self) parameters.
  std::vector<Argument> emitFormalArguments(
      const Def& def,
      const Self* self,
      const FunctionSchema& schema,
      Block* block) {
    std::vector<Argument> arguments; // for schema
    // inputs
    auto it = def.decl().params().begin();
    auto end = def.decl().params().end();
    auto expected_annotation_size = def.decl().params().size();
    if (self) {
      // `self` carries no annotation, so it does not count toward the
      // expected number of annotated parameters.
      expected_annotation_size--;
    }
    if (schema.arguments().size() != expected_annotation_size) {
      throw(
          ErrorReport(def.decl().params().range())
          << "Number of type annotations for"
          << " function parameters (" << schema.arguments().size() << ")"
          << " does not match the number of parameters on the function ("
          << expected_annotation_size << ")!");
    }

    if (self) {
      // The first parameter is bound as a sugared `self` value rather than a
      // plain variable binding.
      AT_ASSERT(it != end);
      const auto& name = (*it).ident().name();
      Value* new_input = block->addInput()->setDebugName(name);
      environment_stack->setSugaredVar(
          (*it).ident().range(),
          name,
          self->makeSugared(new_input),
          /*annotated_type=*/nullptr);
      arguments.emplace_back(name, new_input->type());
      ++it;
    }

    // [setstate type]
    // __setstate__ is special, because if the user leaves it un-annotated we
    // will derive the type for `state` from the output type of __getstate__.
    // This is necessary so that we can allow submodules to appear in `state`.
    bool shouldDeriveType = shouldDeriveSetStateType(def, schema);
    size_t arg_annotation_idx = 0;
    for (; it != end; ++it) {
      auto& name = (*it).ident().name();
      // Add the input to the graph
      Value* new_input = block->addInput();
      if (meaningfulName(name)) {
        new_input->setDebugName(name);
      }
      // Record the type for the schema and set the Type on the Value*
      auto arg = schema.arguments().at(arg_annotation_idx++);
      if (shouldDeriveType) {
        // __setstate__ has exactly one non-self parameter at this point
        // (validated in shouldDeriveSetStateType).
        TORCH_INTERNAL_ASSERT(schema.arguments().size() == 1);
        const auto& inferredStateType = getTypeForSetStateArg(def, self);
        arg = arg.cloneWithType(inferredStateType);
      }

      arguments.push_back(arg);
      new_input->setType(arguments.back().type());

      // NB: set type of new_input before setVar call so the Store is
      // typed appropriately
      environment_stack->setVar((*it).ident().range(), name, new_input);
    }
    return arguments;
  }
898 
emitOutputtorch::jit::to_ir899   Argument emitOutput(
900       const SourceRange& range,
901       const FunctionSchema& schema,
902       Block* block) {
903     // handleMaybeNoReturn ensures that merged_return_type_ is always set
904     auto ret_type = def_stack_.back().merged_return_type_;
905     TORCH_INTERNAL_ASSERT(ret_type);
906 
907     // in the ConvertToSSA pass, prim::ReturnStmts are lowered so that the
908     // correct return value is set. Until then, we have a correctly-typed
909     // placeholder return value. This is needed so that closures & graphs
910     // are correctly typed.
911     auto placeholder_return =
912         graph->insertNode(graph->createUninitialized(ret_type))->output();
913     block->registerOutput(placeholder_return);
914     return Argument("", def_stack_.back().merged_return_type_);
915   }
916 
emitStatementstorch::jit::to_ir917   void emitStatements(const List<Stmt>& statements) {
918     return emitStatements(statements.begin(), statements.end());
919   }
920 
  // XXX: Right now closures are not generically implemented and are only used
  // as an intermediate form for special tasks, like defining gradients or
  // forked functions.
  //
  // There are several unfinished aspects that make them unusable generally
  // 1. We do not have a type, ivalue, operator to represent prim::Closure, so
  // closure_node has type None
  // 2. There is no export logic for it yet, so it cannot be
  // exported/python_printed
  // 3. There is nothing preventing the assignment of already existing variables
  // inside the closures
  //    the changes to those variables will just get forgotten.
  // 4. There is no parsing support in frontend.py, this is intentional since it
  //    prevents people from accidentally using this feature.
  //
  // This function leaves in the graph something like:
  //
  //   %2 : None = prim::Closure()
  //     block0():
  //       %1 : Tensor = prim::DoSomething(%0)
  //       -> (%1)
  //
  // A separate pass is required to erase this closure and replace it with
  // something actually executable (see liftClosure and inlineForkedClosure).
  std::shared_ptr<ClosureValue> emitClosure(
      const std::function<void(Block*)>& emit_body) {
    Node* closure_node = graph->insertNode(graph->create(prim::Closure, 1));
    // it is not a real thing yet, so just say the type is None
    closure_node->output()->setType(NoneType::get());
    Block* block = closure_node->addBlock();
    // The closure body starts outside of any loop, regardless of where the
    // closure itself appears in the enclosing function.
    WithLoopStatus loop_guard(&loop_status_, LoopStatus::NOT_IN_LOOP);
    {
      // Scope the insert point and environment frame to the closure's body
      // block; both are restored before the ClosureValue is constructed.
      WithInsertPoint guard(block);
      pushFrame(block, /*starts_def=*/true);
      emit_body(block);
      popFrame(/*ends_def=*/true);
    }
    return std::make_shared<ClosureValue>(closure_node->output());
  }
960 
emitClosuretorch::jit::to_ir961   void emitClosure(const Def& def) {
962     // invoked once the closure block is set as the environment
963     auto emit_body = [&](Block* closure_block) {
964       emitDef(
965           def,
966           nullptr,
967           closure_block); // ignore schema return, we just wont use it for now
968                           // since we never create a Method for the closure
969     };
970     auto closure_value = emitClosure(emit_body);
971     environment_stack->setSugaredVar(
972         def.name().range(),
973         def.name().name(),
974         closure_value,
975         /*annotated_type=*/nullptr);
976   }
977 
checkBreakContinuetorch::jit::to_ir978   void checkBreakContinue(
979       const SourceRange& loc,
980       const std::string& stmt_name) {
981     if (loop_status_ == LoopStatus::NOT_IN_LOOP) {
982       throw(
983           ErrorReport(loc) << "SyntaxError: '" << stmt_name << "'"
984                            << " outside loop");
985     } else if (loop_status_ == LoopStatus::IN_UNROLLED_LOOP) {
986       throw(
987           ErrorReport(loc)
988           << "Because we emit iteration over modulelists or tuples as "
989              "unrolled loops, we do not support break or continue inside the body of these loops");
990     }
991   }
992 
emitBreaktorch::jit::to_ir993   void emitBreak(const Break& stmt) {
994     checkBreakContinue(stmt.range(), "break");
995     auto break_node =
996         graph->create(prim::BreakStmt, {}, 0)->setSourceRange(stmt.range());
997     graph->insertNode(break_node);
998   }
999 
emitContinuetorch::jit::to_ir1000   void emitContinue(const Continue& stmt) {
1001     checkBreakContinue(stmt.range(), "continue");
1002     auto continue_node =
1003         graph->create(prim::ContinueStmt, {}, 0)->setSourceRange(stmt.range());
1004     graph->insertNode(continue_node);
1005   }
1006 
  // Emit a `del` statement. Supported targets are a single subscripted item
  // (`del x[i]` — class instances dispatch to __delitem__, built-in
  // containers use aten::Delete) and plain variables (`del x`). Slices and
  // anything else are rejected.
  void emitDelete(const Delete& stmt) {
    for (const auto& target : stmt.targets()) {
      if (target.kind() == TK_SUBSCRIPT) {
        // `del x[i]` form.
        Subscript subscript(target);
        const List<Expr>& subscript_exprs = subscript.subscript_exprs();
        if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
          throw(
              ErrorReport(target.range())
              << "del statements only support deletion at a single index, "
                 "slicing is not supported"
                 " (see https://github.com/pytorch/pytorch/issues/31430)");
        }
        const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1);
        const SourceRange& val_range = subscript.value().range();
        Value* idx = emitExpr(subscript_exprs[0]);
        Value* val = sv->asValue(val_range, method);

        // If val is a class instance, this is a method call to a type-specific
        // implementation of del defined in a __delitem__ method.
        if (auto cls = val->type()->cast<ClassType>()) {
          if (!cls->findMethod("__delitem__")) {
            throw(
                ErrorReport(target.range())
                << "Class does not define __delitem__");
          }

          // Use MethodValue to call the method to handle recursion.
          MethodValue(val, "__delitem__")
              .call(stmt.range(), method, {idx}, {}, 0);
        } else {
          // Built-in containers (list/dict) are handled by aten::Delete.
          auto node = graph->create(aten::Delete, {val, idx}, 0)
                          ->setSourceRange(target.range());
          graph->insertNode(node);
        }
      } else if (target.kind() == TK_VAR) {
        // `del x` form: remove the binding from the current scope.
        Var var(target);
        environment_stack->removeVar(var.name(), /*check_if_removed=*/true);
      } else {
        throw(
            ErrorReport(target.range())
            << "del statements are only supported for deleting"
               " list and dict items and variables");
      }
    }
  }
1052 
  // Emit a `return` statement: typecheck the returned value against the
  // declared (or previously merged) return type, record the merged return
  // type on the current def context, and emit a placeholder prim::ReturnStmt
  // (lowered later by ConvertToSSA). Marks the current block as exiting.
  void emitReturn(const Return& stmt) {
    TypePtr declared_return_type =
        def_stack_.back().declared_return_type_; // nullptr if not annotated
    auto actual_return = emitExpr(stmt.expr(), declared_return_type);

    // result type is annotated, every return must convert to that type
    if (declared_return_type) {
      // this guard skips implicit conversion from None -> Tensor for the return
      // type. otherwise forgetting a return a function returning a tensor will
      // cause a None to be converted to a tensor.
      // NOTE(review): as written no type is a subtype of both Tensor and
      // None, so this condition is always true and tryConvertToType always
      // runs; the first operand presumably should test
      // `declared_return_type` — confirm against upstream history before
      // changing.
      if (!(actual_return->type()->isSubtypeOf(*TensorType::get()) &&
            actual_return->type()->isSubtypeOf(*NoneType::get()))) {
        actual_return = tryConvertToType(
            stmt.range(),
            *graph,
            declared_return_type,
            actual_return,
            /*allow_conversions=*/true);
      }
      if (!actual_return->type()->isSubtypeOf(*declared_return_type)) {
        throw(
            ErrorReport(stmt.range())
            << "Return value was annotated as having type "
            << declared_return_type->repr_str() << " but is actually of type "
            << actual_return->type()->repr_str());
      }
    } else {
      // Un-annotated function: unify this return's type with the types of
      // any return statements emitted earlier on other paths.
      declared_return_type = def_stack_.back().merged_return_type_;
      if (!declared_return_type) {
        declared_return_type = actual_return->type();
      }
      auto merged_return_type =
          unifyTypes(declared_return_type, actual_return->type());
      if (!merged_return_type) {
        throw(
            ErrorReport(stmt.range())
            << "Previous return statement returned a value of type "
            << declared_return_type->repr_str()
            << " but this return statement returns a value of type "
            << actual_return->type()->repr_str());
      }
      declared_return_type = merged_return_type.value();
    }
    AT_ASSERT(declared_return_type);

    def_stack_.back().merged_return_type_ = declared_return_type;

    // If the annotated return type is Any and the result type is not Any,
    // cast the result to Any to facilitate type unification between return
    // statements on different code paths (e.g. different branches of an if,
    // body and containing scope of a loop).
    if (declared_return_type == AnyType::get() &&
        actual_return->type() != AnyType::get()) {
      actual_return =
          graph->insertUncheckedCast(actual_return, declared_return_type);
    }

    graph->insertNode(graph->create(prim::ReturnStmt, {actual_return}, 0));
    exit_blocks.insert(environment_stack->block());
  }
1113 
  // Emit each statement in [begin, end) in order, dispatching on statement
  // kind. Emission stops early once the current block is known to exit
  // (e.g. after a return or raise), since the remaining statements are
  // unreachable.
  void emitStatements(
      List<Stmt>::const_iterator begin,
      List<Stmt>::const_iterator end) {
    for (; begin != end; ++begin) {
      auto stmt = *begin;
      // Keep the pending error-report location pointing at the statement
      // currently being emitted.
      ErrorReport::CallStack::update_pending_range(stmt.range());
      switch (stmt.kind()) {
        case TK_IF:
          emitIf(If(stmt));
          break;
        case TK_WHILE:
          emitWhile(While(stmt));
          break;
        case TK_FOR:
          emitFor(For(stmt));
          break;
        case TK_ASSIGN:
          emitAssignment(Assign(stmt));
          break;
        case TK_AUG_ASSIGN:
          emitAugAssignment(AugAssign(stmt));
          break;
        case TK_EXPR_STMT: {
          // Expression statement: emit for side effects, discard the value.
          auto expr = ExprStmt(stmt).expr();
          emitSugaredExpr(expr, 0);
        } break;
        case TK_RAISE:
          emitRaise(Raise(stmt));
          break;
        case TK_ASSERT:
          emitAssert(Assert(stmt));
          break;
        case TK_RETURN: {
          emitReturn(Return(stmt));
        } break;
        case TK_CONTINUE: {
          emitContinue(Continue(stmt));
        } break;
        case TK_BREAK: {
          emitBreak(Break(stmt));
        } break;
        case TK_PASS:
          // Emit nothing for pass
          break;
        case TK_DEF:
          // Nested def becomes a closure (see emitClosure).
          emitClosure(Def(stmt));
          break;
        case TK_DELETE:
          emitDelete(Delete(stmt));
          break;
        case TK_WITH:
          emitWith(With(stmt));
          break;
        default:
          throw(
              ErrorReport(stmt)
              << "Unrecognized statement kind " << kindToString(stmt.kind()));
      }
      // Found an exit statement in this block. The remaining statements aren't
      // reachable so we don't emit them.
      if (exit_blocks.count(environment_stack->block()))
        return;
    }
  }
1178 
  // Compute the type refinements implied by `lhs is None` / `lhs is not
  // None` (`tok` is TK_IS or TK_ISNOT). Only the `var {is, is not} None`
  // shape produces refinements; any other shape yields an empty set.
  RefinementSet findIsNoneRefinements(
      const Expr& lhs,
      Value* lhs_value,
      const Expr& rhs,
      Value* rhs_value,
      int tok) {
    if (rhs.kind() != TK_NONE && lhs.kind() == TK_NONE) {
      // make 'None is var' into 'var is None'
      return findIsNoneRefinements(rhs, rhs_value, lhs, lhs_value, tok);
    }
    if (rhs.kind() != TK_NONE || lhs.kind() != TK_VAR) {
      return {};
    }
    // statement must be var {is, is not} None
    const std::string& name = Var(lhs).name().name();
    // While it should in theory be possible to specialize
    // the `x is None` to know x has type NoneType, we have previously
    // not done this. Unfortunately, doing this will make the type None
    // propagate further in all loaded models. The handling of
    // unwrap_optional will fail in these cases since export did
    // not expect that the input would be none and an unannotated None.
    // To enable this, we need to (1) implement a real casting operator
    // annotated(T, X) that stays in the graph and does the cast
    // and (2) only enable this OPTIONAL_NONE when loading newer
    // graphs because it is incompatible with older graphs.
    // Refinement none(name, RefinementKind::OPTIONAL_NONE);
    if (const auto optional_type = lhs_value->type()->cast<OptionalType>()) {
      // Optional[T]: in the "is not None" world, `name` has type T.
      Refinement present(name, optional_type->getElementType());
      if (tok == TK_IS) {
        return RefinementSet({}, {present});
      } else { // TK_ISNOT
        return RefinementSet({present}, {});
      }
    }
    if (const auto union_type = lhs_value->type()->cast<UnionType>()) {
      // Union containing None: subtract NoneType to obtain the refined type
      // for the "is not None" branch, if anything remains.
      std::vector<TypePtr> to_subtract{NoneType::get()};
      std::optional<TypePtr> remaining =
          union_type->subtractTypeSet(to_subtract);
      std::vector<Refinement> all_present;
      if (remaining) {
        Refinement present{name, *remaining};
        all_present.push_back(std::move(present));
      }
      if (tok == TK_IS) {
        return RefinementSet({}, all_present);
      } else { // TK_ISNOT
        return RefinementSet(all_present, {});
      }
    }
    return RefinementSet();
  }
1230 
  // Emit `expr` as a boolean condition, producing a CondValue that carries
  // the Value, any type refinements implied by the condition, and an
  // optional statically-known truth value used for metacompilation
  // (e.g. `isinstance`, `is None` on statically-known operands).
  CondValue emitCondExpr(const Expr& expr) {
    switch (expr.kind()) {
      case TK_AND:
      case TK_OR: {
        auto binop = BinOp(expr);
        return emitShortCircuitLogical(
            binop.range(), binop.lhs(), binop.rhs(), expr.kind() == TK_OR);
      }
      case TK_NOT: {
        // Negate the inner condition; refinements and any static value are
        // inverted accordingly.
        CondValue v = emitCondExpr(Expr(expr.tree()->trees()[0]));
        Value* result = emitBuiltinCall(
            expr.range(), *graph, aten::__not__, {v.value()}, {});
        std::optional<bool> static_if;
        if (v.staticIf()) {
          static_if = !*v.staticIf();
        }
        return CondValue(result, v.refinements().Not(), static_if);
      } break;
      case TK_IS:
      case TK_ISNOT: {
        // meta programming on AST for is/is not cases and emit branches base on
        auto cond_op = BinOp(expr);
        Value* lhs_val = emitExpr(cond_op.lhs());
        Value* rhs_val = emitExpr(cond_op.rhs());

        auto lhs_none = canBeNone(lhs_val);
        auto rhs_none = canBeNone(rhs_val);

        // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
        //
        // AA, -> statically IS always holds, IS_NOT never holds
        // AN , NA-> statically IS_NOT always holds, IS never holds
        // MA, MM, MN, NM, NN, AM -> cannot prove anything statically
        bool its_is = expr.kind() == TK_IS;
        if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
          return CondValue(*graph, expr.range(), its_is, {});
        } else if (
            (lhs_none == ALWAYS && rhs_none == NEVER) ||
            (lhs_none == NEVER && rhs_none == ALWAYS)) {
          // lhs_val/rhs_val with A/M: only emit never_none_branch
          return CondValue(*graph, expr.range(), !its_is, {});
        } else {
          // Not statically decidable: emit the real is/is-not op and attach
          // the None-refinements it implies.
          auto kind = getNodeKind(expr.kind(), expr.get()->trees().size());
          Value* cond_value = emitBuiltinCall(
              expr.get()->range(),
              *method.graph(),
              kind,
              {lhs_val, rhs_val},
              {});
          auto refinements = RefinementSet(findIsNoneRefinements(
              cond_op.lhs(), lhs_val, cond_op.rhs(), rhs_val, expr.kind()));
          return CondValue(cond_value, refinements, std::nullopt);
        }
      } break;
      default: {
        // Special-case isinstance/hasattr calls so they can be metacompiled.
        if (expr.kind() == TK_APPLY) {
          auto apply = Apply(expr);
          auto callee = Apply(expr).callee();
          if (callee.kind() == TK_VAR) {
            if (Var(callee).name().name() == "isinstance") {
              checkApplyNumInputs(apply, 2);
              return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
            }
            if (Var(callee).name().name() == "hasattr") {
              checkApplyNumInputs(apply, 2);
              return emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
            }
          }
          auto sv = emitSugaredExpr(apply.callee(), 1);
          auto loc = apply.callee().range();
          if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
            if (special_form->form() == prim::isinstance) {
              checkApplyNumInputs(apply, 2);
              return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
            }
          }
        }
        // Generic case: coerce the expression to bool.
        auto expr_out = emitToBool(expr.range(), emitExpr(expr));
        std::optional<bool> static_if = std::nullopt;
        auto kind = expr_out->node()->kind();
        if (kind == aten::is_scripting) {
          static_if = true;
        } else if (kind == aten::has_torch_function) {
          static_if = false;
        }
        // MetaCompile on boolean literals and constants
        if (auto maybe_ivalue = toIValue(expr_out)) {
          static_if = maybe_ivalue->toBool();
        }
        return CondValue(expr_out, RefinementSet({}), static_if);
      } break;
    }
  }
1324 
  // Emit one branch of an `if` into block `b`: a new frame is pushed, the
  // given refinements are materialized at the top of the branch, the branch
  // statements are emitted, and the (popped) branch frame is returned so the
  // caller can merge variable bindings across branches.
  std::shared_ptr<Environment> emitSingleIfBranch(
      Block* b,
      const List<Stmt>& branch,
      const RefinementSet& refinements) {
    pushFrame(b);
    WithInsertPoint guard(b);
    insertRefinements(branch.range(), refinements);
    emitStatements(branch);
    return popFrame();
  }
1335 
createtorch::jit::to_ir1336   Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
1337     return graph->create(kind, n_outputs)->setSourceRange(loc);
1338   }
1339 
emitTernaryIftorch::jit::to_ir1340   Value* emitTernaryIf(
1341       const TernaryIf& expr,
1342       const TypePtr& type_hint = nullptr) {
1343     CondValue cond_value = emitCondExpr(expr.cond());
1344     // If the cond expr is a static value, then we metacompile the `if`
1345     // statemement and only emit true or false branch
1346     if (cond_value.staticIf()) {
1347       if (*cond_value.staticIf()) {
1348         return emitExpr(expr.true_expr(), type_hint);
1349       } else {
1350         return emitExpr(expr.false_expr(), type_hint);
1351       }
1352     }
1353     auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
1354     auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
1355     return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
1356   }
1357 
  // Refine a Union/Optional/Any type hint for a container construction.
  // If the hint is a Union with exactly one contained type matching
  // `type_match`, the hint is refined to it; with several matches,
  // `all_candidates` is populated for later resolution; with none, an error
  // is raised. Optional hints are unwrapped first. For non-Union hints,
  // `do_if_match`/`do_if_anytype` are invoked, or an error is raised when
  // the hint does not match at all. `is_dict_constructor` relaxes the checks
  // for `dict(...)` calls.
  template <class F1, class F2, class F3>
  void refineAndSetUnionTypeHintOrPopulateCandidatesVector(
      const TypePtr& type_hint,
      TypePtr* refined_type_hint_ptr,
      std::vector<TypePtr>* all_candidates,
      const std::string& match_repr,
      const Expr& src,
      const F1& type_match,
      const F2& do_if_match,
      const F3& do_if_anytype,
      bool is_dict_constructor = false) {
    if (auto union_type_hint = (*refined_type_hint_ptr)->cast<UnionType>()) {
      // `candidate_types` holds all List types that were in the Union
      // annotation
      std::vector<TypePtr> candidate_types;

      std::copy_if(
          union_type_hint->containedTypes().begin(),
          union_type_hint->containedTypes().end(),
          std::back_inserter(candidate_types),
          [&](TypePtr type_ptr) { return type_match(type_ptr); });

      if (!is_dict_constructor && candidate_types.empty()) {
        throw(
            ErrorReport(src)
            << "Expected an Union type annotation "
            << "with an inner " << match_repr << " type, but got "
            << (*refined_type_hint_ptr)->repr_str());
      } else if (candidate_types.size() == 1) {
        // The Union only had a single type of the container we want to
        // match, so we can unconditionally refine it to that type
        (*refined_type_hint_ptr) = candidate_types[0];
      } else {
        // We can't refine the Union yet, since it contains multiple
        // types of the container we want to match, but we do at least
        // have a list of possible types (e.g. `Union[List[int],
        // List[str], float, str]` -> candidates={List[int], List[str]})
        (*all_candidates) = std::move(candidate_types);
      }
    } else if (
        auto optional_type_hint =
            (*refined_type_hint_ptr)->cast<OptionalType>()) {
      (*refined_type_hint_ptr) = optional_type_hint->getElementType();
    }

    // This case handles code like `dict([(x, y), (a, b)])` that would
    // otherwise fail the following error checks
    if (is_dict_constructor) {
      return;
    }

    // If we had any annotation that was NOT a Union that can hold more
    // than one type of the container we want to match
    if (all_candidates->empty()) {
      if (type_match(*refined_type_hint_ptr)) {
        do_if_match();
      } else if ((*refined_type_hint_ptr)->kind() == AnyType::Kind) {
        do_if_anytype();
      } else {
        throw(
            ErrorReport(src) << "Expected an annotation of type " << match_repr
                             << " but got " << type_hint->repr_str());
      }
    }
  }
1423 
refineAndSetListTypeHintFromCandidatesVectortorch::jit::to_ir1424   void refineAndSetListTypeHintFromCandidatesVector(
1425       const std::vector<TypePtr>& all_candidates,
1426       const TypePtr& type_hint,
1427       TypePtr* refined_type_hint_ptr,
1428       const TypePtr& unified_elem_type,
1429       const Expr& src) {
1430     TypePtr greatest_elem_type = nullptr;
1431     std::for_each(
1432         all_candidates.begin(),
1433         all_candidates.end(),
1434         [&](const TypePtr& candidate) {
1435           auto candidate_elem_type =
1436               candidate->expect<ListType>()->getElementType();
1437           if (unified_elem_type->isSubtypeOf(candidate_elem_type)) {
1438             if (!greatest_elem_type) {
1439               greatest_elem_type = candidate_elem_type;
1440             } else {
1441               greatest_elem_type =
1442                   *(unifyTypes(greatest_elem_type, candidate_elem_type));
1443             }
1444           }
1445         });
1446     if (!greatest_elem_type) {
1447       std::stringstream vector_repr;
1448       for (size_t i = 0; i < all_candidates.size(); ++i) {
1449         if (i > 0 && all_candidates.size() > 2) {
1450           vector_repr << ", ";
1451         }
1452         if (i != 0 && i == all_candidates.size() - 1) {
1453           vector_repr << " or ";
1454         }
1455         vector_repr << all_candidates[i]->repr_str();
1456       }
1457       throw(
1458           ErrorReport(src) << "Union type annotation `" << type_hint->repr_str()
1459                            << "` can hold " << vector_repr.str()
1460                            << ", but none of "
1461                            << "those types match the types of the given list "
1462                            << "elements, which were unified to "
1463                            << unified_elem_type->repr_str());
1464     } else {
1465       (*refined_type_hint_ptr) = ListType::create(greatest_elem_type);
1466       ;
1467     }
1468   }
1469 
  // Given the candidate Dict types extracted from a Union annotation, pick
  // the most general candidate whose key and value types can hold the known
  // key/value types of the actual dict, and refine the hint to it. Throws if
  // no candidate fits.
  void refineAndSetDictTypeHintFromCandidatesVector(
      const std::vector<TypePtr>& all_candidates,
      const TypePtr& type_hint,
      TypePtr* refined_type_hint_ptr,
      const TypePtr& known_key_type,
      const TypePtr& known_value_type,
      const Expr& src) {
    TypePtr candidate_key_type = nullptr;
    TypePtr candidate_value_type = nullptr;
    TypePtr candidate = nullptr;

    for (const auto& current_candidate : all_candidates) {
      auto current_key_type =
          current_candidate->expect<DictType>()->getKeyType();
      auto current_value_type =
          current_candidate->expect<DictType>()->getValueType();

      if (known_key_type->isSubtypeOf(current_key_type) &&
          known_value_type->isSubtypeOf(current_value_type)) {
        // Prefer the candidate whose key/value types are supertypes of the
        // best one found so far (i.e. the most general admissible Dict).
        if (!candidate ||
            (candidate_key_type->isSubtypeOf(current_key_type) &&
             candidate_value_type->isSubtypeOf(current_value_type))) {
          candidate_key_type = current_key_type;
          candidate_value_type = current_value_type;
          candidate = current_candidate;
        }
      }
    }

    if (!candidate) {
      // Build a human-readable "A, B or C" listing of the candidates.
      std::stringstream vector_repr;
      for (size_t i = 0; i < all_candidates.size(); ++i) {
        if (i > 0 && all_candidates.size() > 2) {
          vector_repr << ", ";
        }
        if (i != 0 && i == all_candidates.size() - 1) {
          vector_repr << " or ";
        }
        vector_repr << all_candidates[i]->repr_str();
      }
      throw(
          ErrorReport(src) << "Union type annotation `" << type_hint->repr_str()
                           << "` can hold " << vector_repr.str()
                           << ", but none of "
                           << "those dict types can hold the types of the given"
                           << " keys and values, which were unified to Dict["
                           << known_key_type->repr_str() << ", "
                           << known_value_type->repr_str());
    } else {
      (*refined_type_hint_ptr) = candidate;
    }
  }
1522 
  // Emits IR for a list comprehension expression (e.g. `[e for t in xs]`)
  // by desugaring it into a `for` loop that appends each element to a
  // freshly created list. `type_hint` is the user's annotation (possibly a
  // Union of several List types) or nullptr. Returns the Value holding the
  // constructed list.
  Value* emitListComprehension(const ListComp& lc, const TypePtr& type_hint) {
    const auto loc = lc.range();
    // Single-target, single-iterable loop fed to emitFor below
    const auto targets_list = List<Expr>::create(lc.range(), {lc.target()});
    const auto itrs = List<Expr>::create(lc.range(), {lc.iter()});

    // If there is no type hint, and this is emitted over an iterable that is
    // unrolled and of length 0, then we emit a List of tensors
    Value* list_value = graph->insertNode(graph->create(prim::ListConstruct, 1))
                            ->output()
                            ->setType(ListType::ofTensors());

    TypePtr refined_type_hint = type_hint;
    // Populated only when `type_hint` is a Union that could hold more than
    // one distinct List type; empty otherwise
    std::vector<TypePtr> all_candidates = {};

    if (refined_type_hint) {
      auto do_if_type_match = [&]() { list_value->setType(refined_type_hint); };

      auto type_match = [&](const TypePtr& t) {
        return t->isSubtypeOf(AnyListType::get());
      };

      // Narrow a Union annotation down to a single List type if possible,
      // otherwise collect the viable List candidates into `all_candidates`
      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "List",
          lc,
          type_match,
          do_if_type_match,
          do_if_type_match);
    }

    bool seen_first_elem = false;

    // A list comprehension introduces its own scope
    Node* n =
        graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
    auto* comprehension_block = n->addBlock();
    pushFrame(comprehension_block);
    WithInsertPoint guard(comprehension_block);
    // Runs once per loop iteration; emits the element and appends it
    auto emit_body = [&]() {
      Value* out = emitExpr(lc.elt());

      // If we didn't have a type annotation, the type of the list would
      // be set to `Tensor`. We don't want to unify this default type
      // with the actual elements in the list, so let the type begin as
      // the first element in the list
      if (!seen_first_elem) {
        list_value->setType(ListType::create(out->type()));
        seen_first_elem = true;
      }

      // Element type from the annotation, if already narrowed to a single
      // List type; used to guide unification below
      const auto elem_type_hint =
          refined_type_hint && refined_type_hint->kind() == ListType::Kind
          ? refined_type_hint->cast<ListType>()->getElementType()
          : nullptr;

      std::optional<TypePtr> unified_elem_type = unifyTypes(
          list_value->type()->expect<ListType>()->getElementType(),
          out->type(),
          /*default_to_union=*/true,
          elem_type_hint);

      // Case: The list comprehension generated heterogenous values,
      // and we don't have a type hint to suggest that this is what the
      // user expected
      if (!type_hint && (*unified_elem_type)->isUnionType()) {
        TORCH_WARN(
            "List consists of heterogeneous types, which means",
            " that it has been typed as containing ",
            (*unified_elem_type)->repr_str(),
            ". To use any of the "
            "values in this List, it will be necessary to add an "
            "`assert isinstance` statement before first use to trigger "
            "type refinement. The first non-matching element was typed",
            " as ",
            out->type()->repr_str(),
            ", while the elements "
            " before it were ",
            list_value->type()
                ->expect<ListType>()
                ->getElementType()
                ->repr_str(),
            "\n",
            lc.range().str());
      }

      // Case: We had an annotation that we were able to narrow down to
      // a single ListType, but the most recently generated element in
      // the list comprehension doesn't match that annotation
      if (all_candidates.empty() && refined_type_hint &&
          !(*unified_elem_type)
               ->isSubtypeOf(*refined_type_hint->expectRef<ListType>()
                                  .getElementType())) {
        throw(
            ErrorReport(lc)
            << "List type annotation `" << refined_type_hint->repr_str()
            << "` did not match the types of the given list elements,"
            << " which were unified to " << (*unified_elem_type)->repr_str());
      }

      if (!all_candidates.empty()) {
        // If we had a Union type annotation that could hold more than
        // one different type of `List`
        refineAndSetListTypeHintFromCandidatesVector(
            all_candidates,
            type_hint,
            &refined_type_hint,
            *unified_elem_type,
            lc);
      } else if (!refined_type_hint) {
        refined_type_hint = ListType::create(*unified_elem_type);
      }

      // Propagate the (possibly updated) refined type to both the list
      // value and the element just produced
      list_value->setType(refined_type_hint);
      out->setType(refined_type_hint->expect<ListType>()->getElementType());

      // Equivalent to `list_value.append(out)`
      NamedValue self = NamedValue(loc, "self", list_value);
      NamedValue input = NamedValue(loc, "", out);
      emitBuiltinCall(loc, *graph, aten::append, {input}, {}, self);
    };
    emitFor(targets_list, itrs, loc, emit_body);
    popFrame();
    return list_value;
  }
1648 
  // Emits IR for a dict comprehension (e.g. `{k: v for t in xs}`) by
  // desugaring it into a `for` loop that inserts one key/value pair per
  // iteration. `type_hint` is the user's annotation (possibly a Union of
  // Dict types) or nullptr. Keys must be homogeneous; values may be
  // heterogeneous, in which case their types unify (possibly to a Union).
  Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) {
    const auto loc = dc.range();
    // Single-target, single-iterable loop fed to emitFor below
    const auto targets_list = List<Expr>::create(dc.range(), {dc.target()});
    const auto itrs = List<Expr>::create(dc.range(), {dc.iter()});

    Value* dict_value =
        graph->insertNode(graph->create(prim::DictConstruct, 1))->output();

    // Set the default type to be Dict[str, Tensor]
    dict_value->setType(DictType::create(StringType::get(), TensorType::get()));

    TypePtr refined_type_hint = type_hint;
    // Remember a Union annotation so the final value can be cast back up to
    // it after emission (see the prim::unchecked_cast at the end)
    TypePtr annotated_union_type =
        type_hint && type_hint->isUnionType() ? type_hint : nullptr;

    // Populated only when the Union annotation could hold more than one
    // distinct Dict type; empty otherwise
    std::vector<TypePtr> all_candidates = {};

    if (refined_type_hint) {
      auto type_match = [&](const TypePtr& t) {
        return t->kind() == DictType::Kind;
      };

      auto do_if_match = [&]() { dict_value->setType(refined_type_hint); };

      // Narrow a Union annotation down to a single Dict type if possible,
      // otherwise collect the viable Dict candidates into `all_candidates`
      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "Dict",
          dc,
          type_match,
          do_if_match,
          do_if_match);
    }

    // Types of the first emitted key/value; later keys must match the key
    // type exactly, later values are unified against the value type
    TypePtr first_generated_key_type = nullptr;
    TypePtr first_generated_value_type = nullptr;

    // A dict comprehension introduces its own scope. No variable assigned
    // may leak into the rest of the graph
    Node* n =
        graph->insertNode(create(prim::ComprehensionScope, dc.range(), 0));
    auto* comprehension_block = n->addBlock();
    pushFrame(comprehension_block);
    WithInsertPoint guard(comprehension_block);
    // Runs once per loop iteration; emits the pair and inserts it
    auto emit_body = [&]() {
      auto k = emitExpr(dc.key());
      auto v = emitExpr(dc.value());

      // If we didn't have a type annotation, the type of the dict would
      // be set to `(str, Tensor)`. We don't want to unify this default
      // type with the actual elements in the dict, so let the type
      // begin as the first element in the dict
      if (k->type()->kind() == UnionType::Kind) {
        // A Union-typed key is rejected outright, since keys must be a
        // single homogeneous type
        throw(
            ErrorReport(dc)
            << "Dicts may only contain homogeneous keys, but the type of "
            << "the first generated key was " << k->type()->repr_str());
      } else if (
          first_generated_key_type && first_generated_key_type != k->type()) {
        // Values can be heterogenous, so we only need to check that the
        // key types are all the same
        throw(
            ErrorReport(dc)
            << "Dicts may only contain homogeneous keys. Expected "
            << "dict comprehension to generate type "
            << first_generated_key_type->repr_str() << ", but got "
            << k->type()->repr_str());
      } else {
        dict_value->setType(DictType::create(k->type(), v->type()));
        first_generated_key_type = k->type();
        first_generated_value_type = v->type();
      }

      // If we had any annotation OTHER THAN a Union that can hold more
      // than one type of Dict
      if (refined_type_hint && all_candidates.empty()) {
        DictTypePtr dict_type_hint = refined_type_hint->expect<DictType>();

        std::stringstream ss;
        std::stringstream err;

        bool is_key_subtype =
            k->type()->isSubtypeOfExt(*dict_type_hint->getKeyType(), &ss);

        if (!is_key_subtype) {
          err << "Dict type annotation `" << dict_type_hint->repr_str()
              << "` did not match the "
              << "type of an actual key type `" << k->type()->repr_str()
              << "`\n"
              << ss.str();
        }

        // Reuse `ss` for the value-side subtype explanation
        ss.str(std::string());
        bool is_value_subtype =
            v->type()->isSubtypeOfExt(*dict_type_hint->getValueType(), &ss);

        if (!is_value_subtype) {
          err << "Dict type annotation `" << dict_type_hint->repr_str()
              << "` did not match the "
              << "type of an actual value type `" << v->type()->repr_str()
              << "`\n"
              << ss.str();
        }

        // Report key and value mismatches together in a single error
        if (!is_key_subtype || !is_value_subtype) {
          throw(ErrorReport(dc) << err.str());
        }
      }

      // Value type from the annotation (if narrowed to a single Dict type),
      // used to guide unification of the generated value types
      const TypePtr value_type_hint =
          refined_type_hint && refined_type_hint->kind() == DictType::Kind
          ? refined_type_hint->expect<DictType>()->getValueType()
          : nullptr;

      std::optional<TypePtr> unified_value_type = unifyTypes(
          first_generated_value_type,
          v->type(),
          /*default_to_union=*/true,
          value_type_hint);

      // Heterogeneous values without an annotation: warn that the value
      // type has become a Union and will need `assert isinstance` refinement
      if (!type_hint && (*unified_value_type)->isUnionType()) {
        TORCH_WARN(
            "Dict values consist of heterogeneous types, which means",
            " that they have been typed as being ",
            (*unified_value_type)->repr_str(),
            ". To use any of the "
            "values in this dict, it will be necessary to add an "
            "`assert isinstance` statement before first use to trigger "
            "type refinement. The first non-matching element was typed",
            " as ",
            v->type()->repr_str(),
            ", while the elements "
            " before it were ",
            first_generated_value_type->repr_str(),
            "\n",
            dc.range().str());
      }

      if (type_hint) {
        if (type_hint->kind() == DictType::Kind) {
          // Annotation was already a concrete Dict type: apply it directly
          dict_value->setType(type_hint);
          k->setType(type_hint->expect<DictType>()->getKeyType());
          v->setType(type_hint->expect<DictType>()->getValueType());
        } else {
          if (!all_candidates.empty()) {
            // Union annotation with multiple viable Dicts: pick the
            // candidate that can hold the generated key/value types
            refineAndSetDictTypeHintFromCandidatesVector(
                all_candidates,
                type_hint,
                &refined_type_hint,
                k->type(),
                *unified_value_type,
                dc);
          }
          dict_value->setType(refined_type_hint);
          k->setType(refined_type_hint->expect<DictType>()->getKeyType());
          v->setType(refined_type_hint->expect<DictType>()->getValueType());
        }
      } else {
        // No annotation: type the dict by the observed key type and the
        // unified value type
        dict_value->setType(DictType::create(k->type(), *unified_value_type));
      }

      // Equivalent to `dict_value[k] = v`
      NamedValue self = NamedValue(loc, "self", dict_value);
      NamedValue input_k = NamedValue(loc, "", k);
      NamedValue input_v = NamedValue(loc, "", v);
      emitBuiltinCall(
          loc, *graph, aten::_set_item, {self, input_k, input_v}, {});
    };
    emitFor(targets_list, itrs, loc, emit_body);
    popFrame();

    // If the original annotation was a Union, cast the result back up so
    // downstream code sees the annotated Union type
    if (annotated_union_type) {
      Node* n =
          graph->insertNode(graph->create(prim::unchecked_cast, {dict_value}));
      n->output()->setType(std::move(annotated_union_type));
      dict_value = n->output();
    }

    return dict_value;
  }
1829 
1830   // Insert subtyping refinements
insertRefinementstorch::jit::to_ir1831   void insertRefinements(const SourceRange& loc, const RefinementSet& ref) {
1832     for (const Refinement& r : ref.activeRefinements()) {
1833       Value* v = environment_stack->getVar(r.identifier(), loc);
1834       Value* new_v = graph->insertUncheckedCast(v, r.type());
1835       environment_stack->setVar(loc, r.identifier(), new_v);
1836     }
1837   }
1838 
  // Emits `a or b` / `a and b` (`is_or` selects which) with Python
  // short-circuit semantics by lowering to an if-expression. Returns the
  // combined CondValue: the boolean result, the merged refinement sets, and
  // a statically known truth value when one can be derived.
  CondValue emitShortCircuitLogical(
      const SourceRange& loc,
      const Expr& first_expr,
      const Expr& second_expr,
      bool is_or) {
    CondValue lhs = emitCondExpr(first_expr);
    // if the continue expr in the short circuit is not evaluated,
    // than the const expression is False if the short circuit
    // is an `and` and True if the short circuit is an `or`.
    // `False and expr` -> False, `True or expr` -> True
    //
    // inserting it as a constant makes optimization easier

    // if it's an OR the first expr is emitted in the true branch
    // and the second expr in the false branch, if it's an AND the opposite
    auto get_const_expr = [&] { return graph->insertConstant(is_or, loc); };

    // `rhs` is filled in by this lambda; emitIfExpr emits both branches, so
    // by the time it returns, `rhs` is always engaged
    std::optional<CondValue> rhs;
    auto get_continue_expr = [&] {
      rhs = emitCondExpr(second_expr);
      return rhs->value();
    };

    // if this is an OR, eval second expression if first expr is False
    // If this is an AND, eval second expression if first expr is True
    Value* new_result = nullptr;
    std::optional<RefinementSet> refinements;
    std::optional<bool> static_if;
    if (is_or) {
      new_result = emitIfExpr(loc, lhs, get_const_expr, get_continue_expr);
      refinements = lhs.refinements().Or(rhs->refinements());
      // `or` is statically true if either side is statically true
      if ((lhs.staticIf() && *lhs.staticIf()) ||
          (rhs->staticIf() && *rhs->staticIf())) {
        static_if = true;
      } else if (lhs.staticIf() && rhs->staticIf()) {
        // Both sides are statically known and (given the branch above)
        // neither is true, so the result is statically known as well
        static_if = *lhs.staticIf() || *rhs->staticIf();
      }
    } else {
      new_result = emitIfExpr(loc, lhs, get_continue_expr, get_const_expr);
      refinements = lhs.refinements().And(rhs->refinements());
      // `and` is statically false if either side is statically false
      if (((lhs.staticIf() && !*lhs.staticIf()) ||
           (rhs->staticIf() && !*rhs->staticIf()))) {
        static_if = false;
      } else if (lhs.staticIf() && rhs->staticIf()) {
        // Both sides are statically known and (given the branch above)
        // neither is false, so the result is statically known as well
        static_if = *lhs.staticIf() && *rhs->staticIf();
      }
    }
    return CondValue(new_result, std::move(*refinements), static_if);
  }
1888 
emitIfExprtorch::jit::to_ir1889   Value* emitIfExpr(
1890       const SourceRange& range,
1891       const CondValue& cond_value,
1892       const std::function<Value*()>& true_expr,
1893       const std::function<Value*()>& false_expr) {
1894     Node* n = graph->insertNode(create(prim::If, range, 0));
1895     n->addInput(cond_value.value());
1896     auto* true_block = n->addBlock();
1897     auto* false_block = n->addBlock();
1898 
1899     auto emit_if_expr = [this, &range](
1900                             Block* b,
1901                             const RefinementSet& refinements,
1902                             const std::function<Value*()>& expr_value) {
1903       pushFrame(b);
1904       WithInsertPoint guard(b);
1905       insertRefinements(range, refinements);
1906       Value* out_val = expr_value();
1907       b->registerOutput(out_val);
1908       popFrame();
1909     };
1910 
1911     emit_if_expr(true_block, cond_value.refinements(), true_expr);
1912     emit_if_expr(false_block, cond_value.refinements().Not(), false_expr);
1913 
1914     auto true_type = true_block->outputs().at(0)->type();
1915     auto false_type = false_block->outputs().at(0)->type();
1916     auto unified = unifyTypes(true_type, false_type);
1917     if (!unified) {
1918       throw(
1919           ErrorReport(range)
1920           << "if-expression's true branch has type " << true_type->repr_str()
1921           << " but false branch has type " << false_type->repr_str());
1922     }
1923 
1924     // Add op outputs
1925     auto expr_value = n->addOutput()->setType(*unified); // Resulting value
1926 
1927     return expr_value;
1928   }
emitToBooltorch::jit::to_ir1929   Value* emitToBool(const SourceRange& loc, Value* v) {
1930     Value* out = nullptr;
1931     try {
1932       auto bool_cast = environment_stack->getSugaredVar("bool", loc);
1933       out = asSimple(bool_cast->call(loc, method, {v}, {}, 0));
1934     } catch (...) {
1935       throw(
1936           ErrorReport(loc) << "Could not cast value of type "
1937                            << v->type()->repr_str() << " to bool");
1938     }
1939     if (!out) {
1940       throw(
1941           ErrorReport(loc) << "Could not cast value of type "
1942                            << v->type()->repr_str() << " to bool");
1943     }
1944     // cast value not response for checking output type
1945     if (!out->type()->isSubtypeOf(*BoolType::get())) {
1946       throw(
1947           ErrorReport(loc)
1948           << "expected a bool expression for condition but found "
1949           << out->type()->repr_str());
1950     }
1951     return out;
1952   }
1953 
emitIfElseBlockstorch::jit::to_ir1954   void emitIfElseBlocks(
1955       const SourceRange& loc,
1956       const CondValue& cond_value,
1957       const List<Stmt>& trueBranch,
1958       const List<Stmt>& falseBranch) {
1959     // this is a static if statement: that is, it contains a subset
1960     // of operators where we are willing to specialize the if statement
1961     // to be only the true or false branch when the condition is statically
1962     // known. This is used to meta-program modules, for instance, when a
1963     // submodule is absent, an is None check can be used to ensure the
1964     // accesses to the None check, which would error, are not compiled.
1965     if (cond_value.staticIf()) {
1966       if (*cond_value.staticIf()) {
1967         insertRefinements(loc, cond_value.refinements());
1968         emitStatements(trueBranch);
1969       } else {
1970         insertRefinements(loc, cond_value.refinements().Not());
1971         emitStatements(falseBranch);
1972       }
1973       return;
1974     }
1975 
1976     Node* n = graph->insertNode(create(prim::If, loc, 0));
1977     n->addInput(cond_value.value());
1978     auto* true_block = n->addBlock();
1979     auto* false_block = n->addBlock();
1980 
1981     // Emit both blocks once to get the union of all mutated values
1982     auto save_true =
1983         emitSingleIfBranch(true_block, trueBranch, cond_value.refinements());
1984     auto save_false = emitSingleIfBranch(
1985         false_block, falseBranch, cond_value.refinements().Not());
1986 
1987     bool true_exits = exit_blocks.count(true_block);
1988     bool false_exits = exit_blocks.count(false_block);
1989     if (true_exits && false_exits) {
1990       exit_blocks.insert(n->owningBlock());
1991     }
1992 
1993     // In python, every variable assigned in an if statement escapes
1994     // the scope of the if statement (all variables are scoped to the function).
1995     // Script is a subset of python: we consider variables to be in scope
1996     // as long as there is a definition of the variable along all paths
1997     // through the if statement
1998     // ----
1999     // if ...:
2000     //   a =
2001     // else:
2002     //   ...
2003     // ... = a  # error, a is not defined along all paths
2004     // ----
2005     // if ...:
2006     //   a =
2007     // else:
2008     //   a =
2009     // ... = a # OK, a is defined along all paths
2010     // ----
2011     // a = ...
2012     // if ...:
2013     //   a =
2014     // ... = a # OK, a is defined along all paths
2015     // if ...:
2016     //   a =
2017     // else:
2018     //   return
2019     // ... = a # OK, a is always defined
2020 
2021     // ordered set, because we want deterministic graph output
2022     std::set<std::string> mutated_variables;
2023 
2024     // When we access either the true or false environment,
2025     // we need to set the insertion point so the prim::Load is inserted
2026     // into the right block.
2027     // if var is only defined in one branch save error in case it's used later
2028     for (auto& v : save_true->definedVariables()) {
2029       {
2030         WithInsertPoint insert(false_block);
2031         if (save_false->findInAnyFrame(v) || false_exits) {
2032           mutated_variables.insert(v);
2033         } else {
2034           if (reportSourceLocation(loc.source()->size())) {
2035             ErrorReport error(loc);
2036             environment_stack->setVariableTypeError(v, [=]() -> std::string {
2037               error << v << " is not defined in the false branch";
2038               return error.what();
2039             });
2040           } else {
2041             environment_stack->setVariableTypeError(v, [=]() -> std::string {
2042               std::stringstream ss;
2043               ss << v << " is not defined in the false branch. "
2044                  << "The source info is eliminated due to the source file is too large. "
2045                  << "To get it back, please set PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION=1 "
2046                  << "as env var";
2047               return ss.str();
2048             });
2049           }
2050         }
2051       }
2052     }
2053     for (auto& v : save_false->definedVariables()) {
2054       {
2055         WithInsertPoint insert(true_block);
2056         if (save_true->findInAnyFrame(v) || true_exits) {
2057           mutated_variables.insert(v);
2058         } else {
2059           if (reportSourceLocation(loc.source()->size())) {
2060             ErrorReport error(loc);
2061             environment_stack->setVariableTypeError(v, [=]() -> std::string {
2062               error << v << " is not defined in the true branch";
2063               return error.what();
2064             });
2065           } else {
2066             environment_stack->setVariableTypeError(v, [=]() -> std::string {
2067               std::stringstream ss;
2068               ss << v << " is not defined in the false branch. "
2069                  << "The source info is eliminated due to the source file is too large. "
2070                  << "To get it back, please set PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION=1 "
2071                  << "as env var";
2072               return ss.str();
2073             });
2074           }
2075         }
2076       }
2077     }
2078 
2079     // Register outputs in each block
2080     for (const auto& x : mutated_variables) {
2081       Value* tv = nullptr;
2082       Value* fv = nullptr;
2083 
2084       {
2085         WithInsertPoint insert(true_block);
2086         if (!true_exits) {
2087           tv = save_true->getVar(x, loc);
2088         }
2089       }
2090       {
2091         WithInsertPoint insert(false_block);
2092         if (!false_exits) {
2093           fv = save_false->getVar(x, loc);
2094         }
2095       }
2096 
2097       // if both branches exit don't emit any variables
2098       // if one branch exits then we allow the all variables in the other branch
2099       // to escape scope since they are well-defined
2100       if (true_exits && false_exits) {
2101         continue;
2102       } else if (true_exits) {
2103         tv = graph->createUninitialized(fv->type())
2104                  ->insertBefore(true_block->return_node())
2105                  ->output();
2106         graph->createStore(x, tv)->insertBefore(true_block->return_node());
2107       } else if (false_exits) {
2108         fv = graph->createUninitialized(tv->type())
2109                  ->insertBefore(false_block->return_node())
2110                  ->output();
2111         graph->createStore(x, fv)->insertBefore(false_block->return_node());
2112       }
2113 
2114       SugaredValuePtr maybe_sugared_x = environment_stack->findInAnyFrame(x);
2115       TypePtr full_type = nullptr;
2116       if (maybe_sugared_x) {
2117         Value* maybe_simple = asSimple(maybe_sugared_x);
2118         if (maybe_simple) {
2119           full_type = maybe_simple->type();
2120         }
2121       }
2122 
2123       // Try to unify the types. If we found a type annotation earlier
2124       // in the environment, and if that type annotation is some form
2125       // of union, then we need to tell `unifyTypes` not to throw an
2126       // error if the branched return types we found are heterogenous
2127       bool default_to_union = full_type &&
2128           (full_type->kind() == UnionType::Kind ||
2129            full_type->kind() == OptionalType::Kind ||
2130            full_type->kind() == NumberType::Kind);
2131       auto unified = unifyTypes(
2132           tv->type(), fv->type(), /*default_to_union=*/default_to_union);
2133 
2134       // We allow variables to be set to different types in each branch
2135       // as long as that variable is not already in scope or if that
2136       // variable does not get used later. Here, we save the error so
2137       // that the error message will be more informative in the case
2138       // that is used later. When `a` is accessed in `(a + 1)`, the
2139       // error will get printed:
2140       // if cond:
2141       //    a = 1
2142       // else:
2143       //    a = tensor
2144       // b = a + 1
2145       //
2146       if (!unified) {
2147         ErrorReport error(loc);
2148         error << "Type mismatch: " << x << " is set to type "
2149               << tv->type()->repr_str() << " in the true branch"
2150               << " and type " << fv->type()->repr_str()
2151               << " in the false branch";
2152         if (save_true->findInParentFrame(x) ||
2153             save_false->findInParentFrame(x)) {
2154           throw ErrorReport(error);
2155         } else {
2156           environment_stack->setVariableTypeError(
2157               x, [=]() -> std::string { return error.what(); });
2158           continue;
2159         }
2160       }
2161       environment_stack->setType(x, *unified);
2162     }
2163   }
2164 
emitHasAttrtorch::jit::to_ir2165   CondValue emitHasAttr(const Expr& objExpr, const Expr& attrExpr) {
2166     auto obj = emitSugaredExpr(objExpr, 1);
2167     if (attrExpr.kind() != TK_STRINGLITERAL) {
2168       throw(
2169           ErrorReport(attrExpr)
2170           << "hasattr's second argument must be a string literal");
2171     }
2172     const std::string& name = StringLiteral(attrExpr).text();
2173     const bool hasAttr = obj->hasAttr(objExpr.range(), method, name);
2174     return CondValue(*graph, objExpr.range(), hasAttr, {});
2175   }
2176 
emitIsInstancetorch::jit::to_ir2177   CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) {
2178     Value* lhs_val = emitExpr(obj);
2179     std::vector<TypePtr> lhs_types;
2180     std::vector<TypePtr> rhs_types;
2181 
2182     std::function<void(const Expr&)> gather_rhs = [&](const Expr& expr) {
2183       if (expr.kind() == TK_TUPLE_LITERAL) {
2184         for (Expr e : TupleLiteral(expr).inputs()) {
2185           gather_rhs(e);
2186         }
2187         return;
2188       }
2189       TypePtr type = typeParser_.parseTypeFromExpr(expr);
2190       rhs_types.emplace_back(type);
2191     };
2192 
2193     lhs_types.push_back(lhs_val->type());
2194     gather_rhs(classinfo);
2195 
2196     standardizeVectorForUnion(&lhs_types);
2197     standardizeVectorForUnion(&rhs_types);
2198 
2199     RefinementSet refinement;
2200 
2201     TypePtr unified_true = nullptr;
2202     TypePtr unified_false = nullptr;
2203 
2204     std::vector<TypePtr> isinstance_types;
2205     std::vector<TypePtr> not_isinstance_types;
2206 
2207     std::vector<Refinement> true_refinements;
2208     std::vector<Refinement> false_refinements;
2209 
2210     bool all_lhs_subtype_some_rhs = true;
2211 
2212     // We can discard any rhs types that we know statically would be
2213     // impossible. For example, if we had:
2214     //
2215     //    def fn(x: Optional[str]):
2216     //        if isinstance(x, (List[str], str, int)):
2217     //            ...
2218     //
2219     // then `x` would be `str` in the true branch and `None` in the
2220     // false branch, not `(List[str], str, int)` in the true branch
2221     // and `None` in the false branch
2222     for (const TypePtr& lhs_type : lhs_types) {
2223       if (lhs_type == AnyType::get()) {
2224         isinstance_types.insert(
2225             isinstance_types.end(), rhs_types.begin(), rhs_types.end());
2226         not_isinstance_types.emplace_back(AnyType::get());
2227         // Edge case: we can still say that all lhs types subtype some
2228         // rhs type if `lhs` is `Any` and `rhs` is `Any`
2229         if (isinstance_types.size() != 1 ||
2230             isinstance_types[0] != AnyType::get()) {
2231           all_lhs_subtype_some_rhs = false;
2232         }
2233         break;
2234       }
2235 
2236       auto get_smaller_type = [&](const TypePtr& t1,
2237                                   const TypePtr& t2) -> TypePtr {
2238         if (t1->isSubtypeOf(*t2)) {
2239           return t1;
2240         } else if (t2->isSubtypeOf(*t1)) {
2241           return t2;
2242         } else {
2243           return nullptr;
2244         }
2245       };
2246 
2247       TypePtr found_refinement = nullptr;
2248       for (const TypePtr& rhs_type : rhs_types) {
2249         TypePtr maybe_smaller_type = get_smaller_type(lhs_type, rhs_type);
2250         if (!maybe_smaller_type) {
2251           continue;
2252         } else if (*maybe_smaller_type == *lhs_type) {
2253           // Cover the case that we have something like
2254           // lhs = `List[str]` and rhs = `list`
2255           found_refinement = lhs_type;
2256         } else if (*maybe_smaller_type == *rhs_type) {
2257           // We want the narrowest possible type
2258           found_refinement = found_refinement
2259               ? *(unifyTypes(found_refinement, rhs_type))
2260               : rhs_type;
2261         }
2262       }
2263 
2264       if (found_refinement) {
2265         if (*found_refinement == *lhs_type) {
2266           all_lhs_subtype_some_rhs &= true;
2267         }
2268         isinstance_types.push_back(found_refinement);
2269       } else {
2270         // If the lhs couldn't be a subtype of the rhs (or couldn't
2271         // be "refined" to itself, as in the `List[str]` and `list`
2272         // case above), then we add `lhs_type` to the false branch
2273         // refinements. This is because the type can still be itself
2274         // if the `isinstance` check is false
2275         not_isinstance_types.push_back(lhs_type);
2276         all_lhs_subtype_some_rhs = false;
2277       }
2278     }
2279 
2280     // For use with `unifyTypeList`
2281     std::stringstream nowhere;
2282 
2283     // Get a single type for the true and false branches
2284     if (!isinstance_types.empty()) {
2285       unified_true =
2286           *unifyTypeList(isinstance_types, nowhere, /*default_to_union=*/true);
2287     }
2288     if (obj.kind() == TK_VAR && unified_true) {
2289       std::string ident = Var(obj).name().name();
2290       true_refinements = {Refinement(ident, unified_true)};
2291     }
2292 
2293     // Get a single type for the true and false branches
2294     if (!not_isinstance_types.empty()) {
2295       unified_false = *unifyTypeList(
2296           not_isinstance_types, nowhere, /*default_to_union=*/true);
2297     }
2298     if (obj.kind() == TK_VAR && unified_false) {
2299       std::string ident = Var(obj).name().name();
2300       false_refinements = {Refinement(ident, unified_false)};
2301     }
2302 
2303     refinement = RefinementSet(true_refinements, false_refinements);
2304 
2305     bool is_statically_false = isinstance_types.empty();
2306 
2307     // If the statement is statically true
2308     if (all_lhs_subtype_some_rhs) {
2309       return CondValue(*graph, obj.range(), true, std::move(refinement));
2310     }
2311 
2312     if (is_statically_false) {
2313       return CondValue(*graph, obj.range(), false, std::move(refinement));
2314     }
2315 
2316     // check maybe true/false at runtime, need an actual op
2317     Value* result =
2318         graph->insertNode(graph->createIsInstance(lhs_val, rhs_types))
2319             ->output();
2320     return CondValue(result, std::move(refinement), std::nullopt);
2321   }
2322 
emitIftorch::jit::to_ir2323   void emitIf(const If& stmt) {
2324     Expr cond = stmt.cond();
2325     CondValue cond_value = emitCondExpr(cond);
2326     emitIfElseBlocks(
2327         stmt.range(), cond_value, stmt.trueBranch(), stmt.falseBranch());
2328   }
2329 
2330   // *********************** Loop Operators ************************************
2331   // Emits a loop operator with the form:
2332   // Loop(max_trip_count)
2333   // block0(loop_counter) {
2334   //   <body>
2335   // }
2336   // block1 {
2337   //   <loop condition>
2338   //   -> (condition)
2339   // }
2340   // For loops will have an empty loop condition block with condition set to
2341   // true. In the convert to ssa pass, the loop condition will correctly
2342   // inlined. and inputs and outputs added so that the loop conforms to the
2343   // semantics specified at
2344   // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
  // Emits the shared prim::Loop structure used by both `for` and `while`.
  //
  //   range     - source location attributed to the emitted nodes
  //   emit_body - callback that emits the loop body statements
  //   iter_val  - sugared iterable for `for` loops; nullptr for `while` loops
  //   targets   - `for` loop target expressions, assigned each iteration
  //   cond      - `while` loop condition; empty for `for` loops
  void emitLoopCommon(
      const SourceRange& range,
      const std::function<void()>& emit_body,
      const SugaredValuePtr& iter_val,
      std::optional<List<Expr>> targets,
      std::optional<Expr> cond) {
    Value* max_trip_count_val = nullptr;
    if (iter_val != nullptr) {
      // For loops: the trip count is the length of the iterable.
      max_trip_count_val = iter_val->len(range, method);
    } else {
      // While loops: effectively unbounded trip count; termination comes
      // from the condition block alone.
      max_trip_count_val = materializeConstant(
          std::numeric_limits<int64_t>::max(),
          *graph,
          range,
          integral_constants);
    }

    Node* n = graph->insertNode(create(prim::Loop, range, 0));
    auto* body_block = n->addBlock();
    {
      // The condition block decides whether another iteration runs.
      Block* condition_block = n->addBlock();
      pushFrame(condition_block);
      Value* out = nullptr;
      if (cond) {
        WithInsertPoint insert(condition_block);
        out = emitToBool(cond.value().range(), emitExpr(cond.value()));
      } else {
        // For loops have no explicit condition: emit a constant `true`;
        // the trip count bounds iteration instead.
        WithInsertPoint insert(n);
        out = graph->insertConstant(true, range);
      }
      condition_block->registerOutput(out);
      popFrame();
    }
    n->addInput(max_trip_count_val);

    WithLoopStatus loop_guard(&loop_status_, LoopStatus::IN_LOOP);
    Value* trip_count =
        body_block->addInput()->setType(IntType::get()); // Iteration num
    {
      pushFrame(body_block);
      WithInsertPoint guard(body_block);

      // if the FOR iters and targets are present, emit FOR target assignments
      if (iter_val != nullptr && targets) {
        Value* cur_elem = iter_val->getitem(range, method, trip_count)
                              ->asValue(range, method);
        SugaredValuePtr sv = std::make_shared<SimpleValue>(cur_elem);
        List<Expr> target_exprs = targets.value();
        validateAssignLhsExpr(target_exprs, range);

        // if target exprs are more than 1, it means iteration unpacking on LHS
        // we create Tuple literal to wrap those target exprs for assignments
        if (target_exprs.size() > 1) {
          Expr tl = TupleLiteral::create(range, target_exprs);
          target_exprs = List<Expr>::create(range, {tl});
        }
        emitExprsAssign(target_exprs, {sv}, range, /*n_binders=*/1);
      }
      emit_body();
      popFrame();
    }
  }
2407 
emitUnrolledLooptorch::jit::to_ir2408   void emitUnrolledLoop(
2409       const SourceRange& loc,
2410       const std::function<void()>& emit_body,
2411       const SugaredValuePtr& iterable,
2412       const List<Expr>& targets) {
2413     auto static_len = iterable->staticLen();
2414     TORCH_INTERNAL_ASSERT(
2415         static_len, "Unrolled loop iter should have static length");
2416     int64_t len = *static_len;
2417     WithLoopStatus loop_guard(&loop_status_, LoopStatus::IN_UNROLLED_LOOP);
2418     // In order to support ModuleLists which return different types,
2419     // as with an nn.Sequential which has a module that returns a Dict and then
2420     // a module which returns a Tensor,
2421     // we do not push a new environment frame because if we did all intermediary
2422     // values would have to subtype the input type.
2423     for (const auto i : c10::irange(len)) {
2424       auto index =
2425           materializeConstant(i, *method.graph(), loc, integral_constants);
2426       auto sugared_value = iterable->getitem(loc, method, index);
2427       emitExprsAssign(
2428           targets, {sugared_value}, targets.range(), /*n_binders=*/1);
2429       emit_body();
2430     }
2431   }
2432 
emitFortorch::jit::to_ir2433   void emitFor(
2434       const List<Expr>& targets,
2435       const List<Expr>& itrs,
2436       const SourceRange& loc,
2437       const std::function<void()>& emit_body) {
2438     if (itrs.size() != 1) {
2439       throw(ErrorReport(loc) << "List of iterables is not supported currently");
2440     }
2441 
2442     // Emit loop information for builtinFunction values like range(), zip(),
2443     // enumerate() or SimpleValue like List, Tensor, Dict, etc.
2444     SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1);
2445     SugaredValuePtr iterable = sv->iter(loc, method);
2446 
2447     // We unroll the loop for iterables that contain ModuleLists so that we can
2448     // compile Heterogenous module lists.
2449     if (!iterable->shouldEmitUnrolled()) {
2450       emitLoopCommon(loc, emit_body, iterable, targets, {});
2451     } else {
2452       emitUnrolledLoop(loc, emit_body, iterable, targets);
2453     }
2454   }
2455 
emitFortorch::jit::to_ir2456   void emitFor(const For& stmt) {
2457     auto emit_body = [&]() { emitStatements(stmt.body()); };
2458     emitFor(stmt.targets(), stmt.itrs(), stmt.range(), emit_body);
2459   }
2460 
emitWhiletorch::jit::to_ir2461   void emitWhile(const While& stmt) {
2462     auto cond = stmt.cond();
2463     auto emit_body = [&]() { emitStatements(stmt.body()); };
2464     emitLoopCommon(stmt.range(), emit_body, nullptr, {}, cond);
2465   }
2466 
  // Emits a `with` statement: for each context-manager target a prim::Enter
  // node is inserted before the body, and matching prim::Exit nodes are
  // inserted after it in reverse (LIFO) order.
  void emitWith(const With& stmt) {
    auto targets = stmt.targets();
    // Keep a stack of entered objects so they can be exited
    // in the right order.
    std::stack<Value*> entered;

    for (const auto& target : targets) {
      Expr e = target.target();

      auto* rhs = emitExpr(e);
      auto* n = graph->insertNode(graph->create(prim::Enter, {rhs}));
      entered.push(rhs);

      // Only class instances may serve as context managers.
      if (rhs->type()->kind() != TypeKind::ClassType) {
        throw(
            ErrorReport(e.range())
            << "With item expression must return an object");
      }

      auto rhsClass = rhs->type()->expect<ClassType>();
      auto* enterMethod = rhsClass->findMethod("__enter__");
      auto* exitMethod = rhsClass->findMethod("__exit__");

      if (!enterMethod || !exitMethod) {
        throw(
            ErrorReport(e.range())
            << "Object returned by with item expression does not define __enter__ and __exit__ methods");
      }

      // Check the schema of __enter__.
      auto& enterSchema = enterMethod->getSchema();
      if (enterSchema.arguments().size() != 1) {
        throw(
            ErrorReport(e.range())
            << "__enter__ must have only one argument and one return value");
      }

      // Check the schema of __exit__: self plus exactly three Any-typed
      // arguments (exception type/value/traceback are not supported).
      auto& exitSchema = exitMethod->getSchema();
      if (exitSchema.arguments().size() != 4) {
        throw(ErrorReport(e.range()) << "__exit__ must have four arguments");
      } else {
        for (unsigned i = 1; i < 4; ++i) {
          if (exitSchema.arguments().at(i).type() != AnyType::get()) {
            throw(
                ErrorReport(e.range())
                << "argument " << i
                << " of __exit__ must have Any type; TorchScript does not currently support passing exception type, value, or traceback to the __exit__ function.");
          }
        }
      }

      // Set the output of the enter node to be the return type of __enter__.
      n->output(0)->setType(enterSchema.returns().at(0).type());

      // Set i = e.__enter__() so that references to i in the body of the with
      // will resolve correctly.
      if (target.var().present()) {
        Var i = target.var().get();
        environment_stack->setVar(i.range(), i.name().name(), n->output(0));
      }
    }

    emitStatements(stmt.body());

    // Insert all the corresponding prim::Exit nodes.
    while (!entered.empty()) {
      auto* input = entered.top();
      entered.pop();
      auto* n = graph->create(prim::Exit);
      graph->insertNode(n);
      n->addInput(input);
    }
  }
2541 
2542   // Currently we do not support assigning exceptions to variables,
2543   // a = Exception("hi")
2544   // raise a
2545   //
2546   // We ignore the expression following raise
  void emitRaise(const Raise& raise) {
    auto sv = emitSugaredExpr(raise.expr(), 1);
    Value* error_message = nullptr;
    // Stays nullptr when the raised expression carries no qualified class
    // name (the bare-class case below).
    Value* qualified_class_name = nullptr;

    if (auto exception_instance =
            std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
      // The typical case, an instance of the exception class was thrown:
      //    raise RuntimeError("error")
      error_message = exception_instance->getValue();
      qualified_class_name = exception_instance->getQualifiedClassName();
    } else if (
        auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
      // A bare exception was thrown so add an empty message. e.g.
      //    raise RuntimeError
      error_message = insertConstant(*graph, "", raise.range());
    } else {
      // The raise was not followed by an exception (i.e. it was something like
      // `raise "error"` instead of `raise RuntimeError("error")`)
      throw(
          ErrorReport(raise.range())
          << "exceptions must derive from BaseException");
    }

    // Coerce non-string messages to string via aten::str.
    if (!error_message->type()->isSubtypeOf(*StringType::get())) {
      error_message = graph->insert(aten::str, {error_message});
    }

    graph->insert(
        prim::RaiseException,
        {error_message, qualified_class_name},
        {},
        raise.range());
    // Everything after a raise in this block is unreachable.
    exit_blocks.insert(environment_stack->block());
  }
2582 
  // Emit assertions as an if branch so that assertions will reuse the
  // message
emitAsserttorch::jit::to_ir2585   void emitAssert(const Assert& stmt) {
2586     CondValue cond_value = emitCondExpr(stmt.test());
2587     List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
2588     // Create an `AssertionError("the_message")` call
2589     auto message = (stmt.msg().present())
2590         ? stmt.msg().get()
2591         : StringLiteral::create(stmt.range(), "");
2592     auto callee = Var::create(
2593         stmt.range(), Ident::create(stmt.range(), "AssertionError"));
2594     auto apply = Apply::create(
2595         stmt.range(),
2596         callee,
2597         List<Expr>::create(stmt.range(), {message}),
2598         List<Attribute>::create(stmt.range(), {}));
2599 
2600     List<Stmt> false_branch =
2601         List<Stmt>::create(stmt.range(), {Raise::create(stmt.range(), apply)});
2602     emitIfElseBlocks(stmt.range(), cond_value, true_branch, false_branch);
2603   }
2604 
2605   // Validate that the `lhs` Expr's in an assignment statement are valid. That
2606   // is:
2607   //
2608   // 1) All lhs Expr's are either Var, Tuple or Starred nodes
2609   // 2) There is at most one Starred node in the lhs Expr
2610   // 3) A Starred node can only appear when there is another non-Starred lhs
2611   //    Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
2612   //    all outputs into a tuple is covered by `abc = func()`.
validateAssignLhsExprtorch::jit::to_ir2613   bool validateAssignLhsExpr(const List<Expr>& lhs, const SourceRange& r) {
2614     size_t num_normal_assign = 0;
2615     size_t num_starred = 0;
2616     for (const auto& assignee : lhs) {
2617       if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT ||
2618           assignee.kind() == TK_TUPLE_LITERAL || assignee.kind() == '.') {
2619         num_normal_assign++;
2620       } else if (assignee.kind() == TK_STARRED) {
2621         num_starred++;
2622       } else {
2623         throw(
2624             ErrorReport(assignee) << "lhs of assignment must be a variable, "
2625                                   << "subscript, or starred expression");
2626       }
2627     }
2628 
2629     if (num_starred > 1) {
2630       throw(
2631           ErrorReport(r)
2632           << "Only one starred expression is allowed on the lhs");
2633     }
2634 
2635     if (num_starred > 0 && num_normal_assign == 0) {
2636       throw(
2637           ErrorReport(r) << "A Starred expression may only appear on the "
2638                          << "lhs within the presence of another non-starred"
2639                          << " expression");
2640     }
2641 
2642     return num_starred;
2643   }
2644 
2645   // Get the appropriate builtin op for this augmented assignment
2646   // If the RHS is a tensor, return the corresponding ATen in-place op
2647   // If it's a list of scalars, then return the corresponding list augment op
  Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) {
    // In-place variants are selected for tensors and lists; scalars fall
    // back to the out-of-place builtin.
    bool use_inplace_op = type->isSubtypeOf(*TensorType::get()) ||
        type->kind() == TypeKind::ListType;
    switch (stmt.aug_op()) {
      case '+':
        return use_inplace_op ? aten::add_ : aten::add;
      case '-':
        return use_inplace_op ? aten::sub_ : aten::sub;
      case '/':
        return use_inplace_op ? aten::div_ : aten::div;
      case '*':
        return use_inplace_op ? aten::mul_ : aten::mul;
      case '%':
        return use_inplace_op ? aten::fmod_ : aten::fmod;
      // NOTE(review): unlike the arithmetic cases above, the bitwise cases
      // map the "inplace" side to the out-of-place aten::bitwise_* symbols
      // — confirm this asymmetry is intended.
      case '|':
        return use_inplace_op ? aten::bitwise_or : aten::__or__;
      case '&':
        return use_inplace_op ? aten::bitwise_and : aten::__and__;
      case '^':
        return use_inplace_op ? aten::bitwise_xor : aten::__xor__;
      case TK_LSHIFT:
        return use_inplace_op ? aten::__ilshift__ : aten::__lshift__;
      case TK_RSHIFT:
        return use_inplace_op ? aten::__irshift__ : aten::__rshift__;
      case TK_POW:
        // No in-place variant is used for exponentiation.
        return aten::pow;
      default:
        throw(
            ErrorReport(stmt)
            << "Unknown augmented assignment: " << kindToString(stmt.aug_op()));
    }
  }
2680 
2681   // Get a pair of <in place magic method name, out of place magic method name>
2682   // since the out of place method is called if the in place method is not
2683   // present
getAugMagicMethodtorch::jit::to_ir2684   std::pair<std::string, std::string> getAugMagicMethod(const AugAssign& stmt) {
2685     switch (stmt.aug_op()) {
2686       case '+':
2687         return std::make_pair(std::string("__iadd__"), std::string("__add__"));
2688       case '-':
2689         return std::make_pair(std::string("__isub__"), std::string("__sub__"));
2690       case '/':
2691         return std::make_pair(
2692             std::string("__itruediv__"), std::string("__truediv__"));
2693       case '*':
2694         return std::make_pair(std::string("__imul__"), std::string("__mul__"));
2695       case '%':
2696         return std::make_pair(std::string("__imod__"), std::string("__mod__"));
2697       default:
2698         throw(
2699             ErrorReport(stmt)
2700             << "Unknown augmented assignment: " << kindToString(stmt.aug_op()));
2701     }
2702   }
2703 
2704   // Emit nodes for augmented assignments like `+=`
emitAugAssignmenttorch::jit::to_ir2705   void emitAugAssignment(const AugAssign& stmt) {
2706     switch (stmt.lhs().kind()) {
2707       case TK_VAR: {
2708         emitAugAssignmentToVar(stmt);
2709       } break;
2710       case '.': {
2711         emitAugAssignmentToSelectVar(stmt);
2712       } break;
2713       case TK_SUBSCRIPT: {
2714         emitAugAssignmentToSubscript(stmt);
2715       } break;
2716       default:
2717         throw(
2718             ErrorReport(stmt.lhs())
2719             << "unexpected expression on "
2720             << "left-hand side of augmented assignment");
2721     }
2722   }
2723 
2724   // This will be called when there is a class param or module buffer
2725   // mutation which make the LHS of the expr be a select expression
2726   //
2727   // Example like:
2728   // class A(Module):
2729   //  def __init__():
2730   //    self.register_buffer("running_var", torch.zeros(1))
2731   //
2732   //  def forward():
2733   //    self.num_batches += 1
emitAugAssignmentToSelectVartorch::jit::to_ir2734   void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
2735     const auto lhs = Select(stmt.lhs());
2736     auto lhsSugaredVar = emitSugaredExpr(lhs.value(), 1);
2737     const auto lhsValue =
2738         lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
2739             ->asValue(lhs.range(), method);
2740     auto result = emitAugAssignmentHelper(stmt, lhsValue);
2741     lhsSugaredVar->setAttr(stmt.range(), method, lhs.selector().name(), result);
2742   }
2743 
emitAugAssignmentToVartorch::jit::to_ir2744   void emitAugAssignmentToVar(const AugAssign& stmt) {
2745     const auto lhs = Var(stmt.lhs());
2746     auto lhsValue = emitExpr(lhs);
2747     auto result = emitAugAssignmentHelper(stmt, lhsValue);
2748     environment_stack->setVar(lhs.range(), lhs.name().name(), result);
2749   }
2750 
emitAugAssignmentHelpertorch::jit::to_ir2751   Value* emitAugAssignmentHelper(const AugAssign& stmt, Value* lhs) {
2752     if (lhs->type()->kind() == TypeKind::ClassType) {
2753       // Call `__iadd__` so updates happen in place on class types
2754       // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
2755       std::string in_place_method_name;
2756       std::string out_of_place_method_name;
2757       std::tie(in_place_method_name, out_of_place_method_name) =
2758           getAugMagicMethod(stmt);
2759       const auto rhs = emitExpr(stmt.rhs());
2760 
2761       // Determine whether to use __iadd__ or __add__ (use __add__ only if
2762       // __iadd__ is not present)
2763       auto type = lhs->type()->expect<ClassType>();
2764       std::string magic_method_name;
2765       if (type->findMethod(in_place_method_name)) {
2766         magic_method_name = in_place_method_name;
2767       } else if (type->findMethod(out_of_place_method_name)) {
2768         magic_method_name = out_of_place_method_name;
2769       } else {
2770         throw(
2771             ErrorReport(stmt.range())
2772             << "Cannot emit inplace op on " << type->repr_str()
2773             << " since it does not define an " << in_place_method_name << " or "
2774             << out_of_place_method_name << " method");
2775       }
2776 
2777       // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
2778       // __iadd__ is not present
2779       return MethodValue(lhs, magic_method_name)
2780           .call(stmt.range(), method, {rhs}, {}, 0)
2781           ->asValue(stmt.range(), method);
2782     } else {
2783       const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()))
2784                            .value(*method.graph());
2785       return emitBuiltinCall(
2786           stmt.range(),
2787           *method.graph(),
2788           getAugOp(stmt, lhs->type()),
2789           /*args=*/{lhs, rhs},
2790           /*kwargs=*/{},
2791           /*self=*/std::nullopt);
2792     }
2793   }
2794 
  // Handles augmented assignment on subscripted non-tensor containers,
  // e.g. `l[0] += 1`: read via __getitem__, apply the op, write back via
  // _set_item.
  void emitAugAssignmentGeneric(
      const AugAssign& stmt,
      const Subscript& lhs,
      Value* sliceable) {
    // Get the idx to augment
    const auto subscriptExprs = lhs.subscript_exprs();
    const TypePtr type = sliceable->type();
    if (subscriptExprs.size() != 1) {
      throw(
          ErrorReport(subscriptExprs)
          << "Sliced expression not yet supported for " << type->repr_str()
          << " augmented assignment. "
          << "File a bug if you want this");
    }

    // Element type used to select the augmented op via getAugOp.
    // NOTE(review): for dicts this reads getKeyType() — presumably the
    // value type is what gets augmented; confirm against DictType usage.
    TypePtr elemType = nullptr;
    if (const ListTypePtr listType = type->cast<ListType>()) {
      elemType = listType->getElementType();
    } else if (const DictTypePtr dictType = type->cast<DictType>()) {
      elemType = dictType->getKeyType();
    }

    if (elemType == nullptr) {
      throw(
          ErrorReport(lhs) << type->repr_str()
                           << " does not support augmented assignment.");
    }
    const auto idxValue = emitExpr(subscriptExprs[0]);
    const auto containerArg =
        NamedValue(lhs.value().range(), type->str(), sliceable);
    const auto idxArg = NamedValue(subscriptExprs.range(), "idx", idxValue);
    const auto valueArg =
        NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));

    // Read-modify-write: container[idx] = container[idx] <op> value.
    const auto getItem = graph->insert(
        aten::__getitem__, {containerArg, idxArg}, {}, stmt.range());
    const auto augmentedItem = graph->insert(
        getAugOp(stmt, elemType), {getItem, valueArg}, {}, stmt.range());
    graph->insert(
        aten::_set_item,
        {containerArg, idxArg, augmentedItem},
        {},
        stmt.range());
  }
2839 
  // Emits augmented assignment where the lhs is subscripted, e.g.
  // `x[0] += y`. Tensors use in-place ops (or index_put_ for advanced
  // indexing); everything else defers to emitAugAssignmentGeneric.
  void emitAugAssignmentToSubscript(const AugAssign& stmt) {
    // Process the base list value
    const auto lhs = Subscript(stmt.lhs());
    const auto sliceable = emitExpr(lhs.value());

    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
      // If it's a tensor, just fully evaluate the subscript operation and emit
      // an in-place assignment
      auto [sliced, tensorIndices] = emitIntAndSliceIndexing(
          lhs.range(), sliceable, lhs.subscript_exprs());

      const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced);
      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
      if (tensorIndices.empty()) {
        // Common case: we only tried to index with int and slices. Emit the
        // correct augmented assignment op to the sliced value
        emitBuiltinCall(
            stmt.range(),
            *method.graph(),
            getAugOp(stmt, sliceable->type()),
            {rhs},
            {},
            slicedArg);
      } else {
        // Special case: we tried to do "advanced indexing". Lower this expr
        // into `index` and `index_put_` ops with tensordices of Tensor?[]
        const auto indices = graph
                                 ->insertNode(graph->createList(
                                     OptionalType::ofTensor(), tensorIndices))
                                 ->output();
        const auto indexed =
            graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
        const auto augmented = emitBuiltinCall(
            stmt.range(),
            *method.graph(),
            getAugOp(stmt, sliceable->type()),
            {rhs},
            {},
            indexed);
        graph->insert(
            aten::index_put_,
            {slicedArg, indices, augmented},
            {},
            stmt.range());
      }
    } else {
      emitAugAssignmentGeneric(stmt, lhs, sliceable);
    }
  }
2889 
emitValueToTensortorch::jit::to_ir2890   NamedValue emitValueToTensor(
2891       const NamedValue& value,
2892       const NamedValue& matchTypeOf) {
2893     // Add implicit conversion of int/float/complex/bool/number types to tensors
2894     // Used in emitSubscriptAssign to convert:
2895     //   `tensor(...)[x] = 99` to `tensor(...)[x] = tensor(99)`
2896     // Mirrors the `valueToTensor` behavior in python_variable_indexing.cpp
2897     const auto kind = value.type()->kind();
2898     if (kind == c10::TypeKind::NumberType || kind == c10::TypeKind::IntType ||
2899         kind == c10::TypeKind::BoolType || kind == c10::TypeKind::FloatType ||
2900         kind == c10::TypeKind::ComplexType) {
2901       auto dtype = graph->insert(prim::dtype, {matchTypeOf}, {});
2902       auto device = graph->insert(prim::device, {matchTypeOf}, {});
2903       auto converted = graph->insert(
2904           aten::tensor,
2905           {value},
2906           {NamedValue("dtype", dtype), NamedValue("device", device)});
2907       return NamedValue(value.loc(), converted);
2908     }
2909 
2910     return value;
2911   }
2912 
2913   // Emit mutating assignments like `foo[0] = bar`
emitSubscriptAssigntorch::jit::to_ir2914   void emitSubscriptAssign(
2915       const SourceRange& stmtRange,
2916       const Subscript& lhs,
2917       const Expr& rhs) {
2918     emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
2919   }
2920 
  // Overload taking an already-emitted rhs. Tensors are written with
  // copy_ (or index_put_ for advanced indexing); other containers and
  // classes go through __setitem__ / aten::_set_item.
  void emitSubscriptAssign(
      const SourceRange& stmtRange,
      const Subscript& lhs,
      const NamedValue& rhs) {
    // First check the base value.
    auto sliceable = emitExpr(lhs.value());

    // If it's a tensor, copy the RHS data into it
    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
      // Handle multi-dimensional slicing: first emit int/slice indexing
      // TODO: the Python equivalent code has special-cased copy_to
      // broadcasting to match NumPy semantics (see PR#4853). We can't
      // replicate that without knowing the size of the Tensor; so really that
      // code should be moved into the aten function
      auto [sliced, tensorIndices] = emitIntAndSliceIndexing(
          lhs.range(), sliceable, lhs.subscript_exprs());

      const auto slicedArg = NamedValue(lhs.range(), sliced);

      // rhs must be a tensor, implicitly convert int/float/complex/bool
      const auto convertedRhs = emitValueToTensor(rhs, slicedArg);

      if (tensorIndices.empty()) {
        // Common case: we only tried to index with int and slices. Copy the
        // RHS into the resulting tensor.
        graph->insert(aten::copy_, {slicedArg, convertedRhs}, {}, stmtRange);
      } else {
        // Special case: we tried to do "advanced indexing" with a tensor.
        // Dispatch to `aten::index_put_` with tensorindices of Tensor?[]
        const auto indices = graph
                                 ->insertNode(graph->createList(
                                     OptionalType::ofTensor(), tensorIndices))
                                 ->output();

        graph->insert(
            aten::index_put_,
            {slicedArg, indices, convertedRhs},
            {},
            stmtRange);
      }
      // Otherwise, this is a list or a classtype.
      // Dispatch to aten::_set_item to both select and assign
    } else {
      const auto subscript = lhs.subscript_exprs();
      // Only single, non-slice subscripts are supported on this path.
      if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
        throw(
            ErrorReport(subscript) << "Sliced expression not yet supported for"
                                   << " subscripted assignment. "
                                   << "File a bug if you want this");
      }
      // Tuples are immutable and cannot be assigned into.
      if (sliceable->type()->isSubtypeOf(*AnyTupleType::get())) {
        throw(
            ErrorReport(lhs) << sliceable->type()->repr_str()
                             << " does not support subscripted assignment");
      }

      std::vector<NamedValue> args;
      args.emplace_back(lhs.value().range(), "self", sliceable);
      args.emplace_back(
          lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0]));
      args.push_back(rhs);
      // Prefer a user-defined __setitem__ when present; otherwise fall
      // back to the builtin aten::_set_item.
      makeMagic(
          "__setitem__",
          std::make_shared<BuiltinFunction>(aten::_set_item, std::nullopt))
          ->call(stmtRange, method, args, {}, 0);
    }
  }
2988 
emitTupleAssigntorch::jit::to_ir2989   void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
2990     size_t n_binders = tl.inputs().size();
2991     bool starred_unpack = validateAssignLhsExpr(tl.inputs(), tl.range());
2992     if (starred_unpack)
2993       n_binders--;
2994     auto output = emitSugaredExpr(rhs, n_binders);
2995     emitTupleAssign(tl, output, rhs.range(), n_binders, starred_unpack);
2996   }
2997 
emitTupleAssigntorch::jit::to_ir2998   void emitTupleAssign(
2999       const TupleLiteral& tl,
3000       const SugaredValuePtr& rhs_output,
3001       const SourceRange& rhs_loc,
3002       size_t n_binders,
3003       bool starred_unpack) {
3004     auto outputs = rhs_output->asTuple(
3005         rhs_loc,
3006         method,
3007         starred_unpack ? std::nullopt : std::optional<size_t>{n_binders});
3008     if (outputs.size() < n_binders) {
3009       throw(
3010           ErrorReport(tl) << "need " << (starred_unpack ? "at least " : "")
3011                           << n_binders << " values to unpack but found only "
3012                           << outputs.size());
3013     }
3014     if (outputs.size() > n_binders && !starred_unpack) {
3015       throw(
3016           ErrorReport(tl) << "too many values to unpack: need " << n_binders
3017                           << " but found " << outputs.size());
3018     }
3019 
3020     emitExprsAssign(tl.inputs(), outputs, rhs_loc, n_binders);
3021   }
3022 
  // Assign each evaluated output to the corresponding LHS target expression.
  //
  // lhs_exprs: the unpacking targets (variables, subscripts, attribute
  //   selects, nested tuple literals, or at most one starred target).
  // outputs: the already-evaluated RHS values; one per non-starred target,
  //   plus any surplus that a starred target absorbs.
  // rhs_loc: source range of the RHS, used for error reporting.
  // n_binders: number of non-starred targets; `outputs.size() - n_binders`
  //   values get packed into the starred target, if present.
  void emitExprsAssign(
      const List<Expr>& lhs_exprs,
      const at::ArrayRef<SugaredValuePtr> outputs,
      const SourceRange& rhs_loc,
      size_t n_binders) {
    // i indexes into `outputs`; a starred target advances it by more than one.
    size_t i = 0;
    for (auto assignee : lhs_exprs) {
      switch (assignee.kind()) {
        case TK_SUBSCRIPT:
          // x[k] = value
          emitSubscriptAssign(
              rhs_loc,
              Subscript(assignee),
              NamedValue(rhs_loc, outputs.at(i)->asValue(rhs_loc, method)));
          i++;
          break;
        case TK_VAR:
          // Plain variable target: bind the sugared value directly.
          environment_stack->setSugaredVar(
              assignee.range(),
              Var(assignee).name().name(),
              outputs.at(i),
              /*annotated_type=*/nullptr);
          i++;
          break;
        case TK_STARRED: {
          // *x: pack the surplus values into a tuple and bind it to x.
          auto var = Starred(assignee).expr();
          if (var.kind() != TK_VAR) {
            throw(
                ErrorReport(var) << "Cannot pack a tuple into a non-variable");
          }
          size_t n_matched = outputs.size() - n_binders;
          ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
          auto values = fmap(
              outputs_ref.slice(i, n_matched),
              [&](const std::shared_ptr<SugaredValue>& v) {
                return v->asValue(assignee.range(), method);
              });
          auto tup = graph->insertNode(graph->createTuple(values))->output();
          environment_stack->setVar(var.range(), Var(var).name().name(), tup);
          i += n_matched;
        } break;
        case TK_TUPLE_LITERAL: {
          // recursively emit tuple assignments on tuple literal input
          TupleLiteral sub_tl = TupleLiteral(assignee);
          size_t sub_n_binders = sub_tl.inputs().size();
          bool sub_starred_unpack =
              validateAssignLhsExpr(sub_tl.inputs(), sub_tl.range());
          if (sub_starred_unpack)
            sub_n_binders--;
          emitTupleAssign(
              sub_tl,
              outputs.at(i),
              rhs_loc,
              sub_n_binders,
              sub_starred_unpack);
          i++;
        } break;
        case '.': {
          // obj.attr = value
          emitSelectAssign(assignee, outputs.at(i), rhs_loc);
          i++;
        } break;
        default:
          throw(
              ErrorReport(assignee)
              << "unexpected expression on the left-hand side");
      }
    }
  }
3090 
emitAssignmenttorch::jit::to_ir3091   void emitAssignment(const Assign& stmt) {
3092     if (stmt.lhs_list().size() == 1) {
3093       return emitSingleAssignment(stmt);
3094     }
3095     // multiple assign & annotated type not supported in python
3096     TORCH_INTERNAL_ASSERT(stmt.lhs_list().size() > 1 && !stmt.type().present());
3097     // a = b = expr()
3098     // the semantics of multiple assignment is that expr() is emitted once, then
3099     // from left to right the assignments are made
3100     const auto tmp_name = createTempName("$tmp_assign_");
3101     environment_stack->setSugaredVar(
3102         stmt.rhs().range(),
3103         tmp_name,
3104         emitSugaredExpr(stmt.rhs().get(), 1),
3105         /*annotated_type=*/nullptr);
3106     auto ident = Var::create(
3107         stmt.rhs().range(), Ident::create(stmt.rhs().range(), tmp_name));
3108     for (auto expr : stmt.lhs_list()) {
3109       emitSingleAssignment(Assign::create(
3110           stmt.range(),
3111           List<Expr>::create(expr.range(), {expr}),
3112           Maybe<Expr>::create(stmt.rhs().range(), ident),
3113           Maybe<Expr>::create(stmt.range())));
3114     }
3115   }
3116 
  // Emit an assignment with exactly one LHS target, dispatching on the
  // target's syntactic kind: plain variable, tuple literal, attribute
  // select, or subscript. Raises ErrorReport if the statement has no RHS or
  // the LHS kind is unsupported.
  void emitSingleAssignment(const Assign& stmt) {
    if (!stmt.rhs().present()) {
      throw(
          ErrorReport(stmt.range())
          << "For an assignment, expected an expression on the right-hand side");
    }
    const Expr& rhs = stmt.rhs().get();
    switch (stmt.lhs().kind()) {
      case TK_VAR: {
        auto v = Var(stmt.lhs());
        // Optional annotation (`x: T = rhs`) becomes a type hint for both RHS
        // emission and the variable binding below.
        TypePtr type = nullptr;
        if (stmt.type().present()) {
          type = typeParser_.parseTypeFromExpr(stmt.type().get());
        }
        auto rhs_sugared_val = emitSugaredExpr(rhs, 1, type);
        // START BC HACK
        //
        // For old serialized quantized RNN modules, switch
        // quantized::linear_prepack to quantized::linear_prepack_legacy. We
        // changed linear_prepack to return a TorchBind class and not a
        // cpp_custom_type_hack tensor anymore, but the old serialized models
        // are tightly coupled with the type_hack version. If we still create a
        // Tensor here, then the quantized_lstm.legacy overload can kick in in
        // forward_impl(), and the module will still run correctly.
        if (method.qualname() ==
            "__torch__.torch.nn.quantized.dynamic.modules.rnn.PackedParameter.__setstate__") {
          if (auto sv =
                  std::dynamic_pointer_cast<SimpleValue>(rhs_sugared_val)) {
            Node* rhs_node = sv->getValue()->node();
            if (rhs_node->kind() ==
                Symbol::fromQualString("quantized::linear_prepack")) {
              // Re-emit the same call against the legacy op, reusing the
              // original node's inputs and source range.
              std::vector<NamedValue> inputs;
              for (Value* i : rhs_node->inputs()) {
                inputs.emplace_back(i);
              }
              Value* new_val = rhs_node->owningGraph()->insert(
                  Symbol::fromQualString("quantized::linear_prepack_legacy"),
                  inputs,
                  {},
                  rhs_node->sourceRange());
              rhs_sugared_val = std::make_shared<SimpleValue>(new_val);
            }
          }
        }
        // END BC HACK
        environment_stack->setSugaredVar(
            v.range(),
            v.name().name(),
            std::move(rhs_sugared_val),
            /*annotated_type=*/type);
      } break;
      case TK_TUPLE_LITERAL:
        // a, b = rhs
        emitTupleAssign(TupleLiteral(stmt.lhs()), rhs);
        break;
      case '.':
        // obj.attr = rhs
        emitSelectAssign(stmt);
        break;
      case TK_SUBSCRIPT:
        // x[k] = rhs
        emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), rhs);
        break;
      default:
        throw(
            ErrorReport(stmt.lhs())
            << "unexpected expression on left-hand side of assignment");
    }
  }
3183 
emitSelectAssigntorch::jit::to_ir3184   void emitSelectAssign(const Assign& stmt) {
3185     if (!stmt.rhs().present()) {
3186       throw(ErrorReport(stmt.range()) << "Expected RHS for assignment");
3187     }
3188 
3189     TypePtr type_hint = nullptr;
3190     if (stmt.type().present()) {
3191       type_hint = typeParser_.parseTypeFromExpr(stmt.type().get());
3192     }
3193     const auto lhs = Select(stmt.lhs());
3194     auto lhsObject = emitSugaredExpr(lhs.value(), 1);
3195     const auto rhsValue = emitSugaredExpr(stmt.rhs().get(), 1, type_hint)
3196                               ->asValue(stmt.rhs().range(), method);
3197     lhsObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue);
3198   }
3199 
emitSelectAssigntorch::jit::to_ir3200   void emitSelectAssign(
3201       const Expr& lhs,
3202       const SugaredValuePtr& rhs,
3203       const SourceRange& loc) {
3204     const auto lhs_select = Select(lhs);
3205     auto lhs_sv = emitSugaredExpr(lhs_select.value(), 1);
3206     const auto rhs_value = rhs->asValue(loc, method);
3207     lhs_sv->setAttr(loc, method, lhs_select.selector().name(), rhs_value);
3208   }
3209 
  // Maps a frontend token / operator character kind to the builtin symbol
  // (aten::* / prim::*) emitted for it in the graph. `ninputs` is currently
  // unused by this function. Throws std::runtime_error for unknown kinds.
  NodeKind getNodeKind(int kind, size_t ninputs) {
    switch (kind) {
      case '+':
        return aten::add;
      case '-':
        return aten::sub;
      case TK_UNARY_MINUS:
        return aten::neg;
      case '*':
        return aten::mul;
      case TK_POW:
        return aten::pow;
      case '@':
        return aten::matmul;
      case TK_STARRED:
        return prim::Starred;
      case '/':
        return aten::div;
      case '%':
        return aten::remainder;
      case TK_NE:
        return aten::ne;
      case TK_EQ:
        return aten::eq;
      case '<':
        return aten::lt;
      case '>':
        return aten::gt;
      case TK_LE:
        return aten::le;
      case TK_GE:
        return aten::ge;
      case TK_AND:
        return aten::__and__;
      case TK_OR:
        return aten::__or__;
      case TK_IS:
        return aten::__is__;
      case TK_ISNOT:
        return aten::__isnot__;
      case TK_NOT:
        return aten::__not__;
      case TK_FLOOR_DIV:
        return aten::floordiv;
      case TK_LSHIFT:
        return aten::__lshift__;
      case TK_RSHIFT:
        return aten::__rshift__;
      case '&':
        return aten::__and__;
      case '|':
        return aten::__or__;
      case '^':
        return aten::__xor__;
      case TK_IN:
        return aten::__contains__;
      default:
        throw std::runtime_error("unknown kind " + std::to_string(kind));
    }
  }
3270 
  // Maps a frontend token / operator character kind to the Python magic
  // method name used for operator overloads (e.g. '+' -> "__add__").
  // `ninputs` is currently unused by this function. Throws
  // std::runtime_error for unknown kinds.
  std::string getOperatorOverload(int kind, size_t ninputs) {
    switch (kind) {
      case '+':
        return "__add__";
      case '-':
        return "__sub__";
      case TK_UNARY_MINUS:
        return "__neg__";
      case '~':
        return "__invert__";
      case '*':
        return "__mul__";
      case TK_POW:
        return "__pow__";
      case '/':
        return "__truediv__";
      case '%':
        return "__mod__";
      case TK_NE:
        return "__ne__";
      case TK_EQ:
        return "__eq__";
      case '<':
        return "__lt__";
      case '>':
        return "__gt__";
      case TK_LE:
        return "__le__";
      case TK_GE:
        return "__ge__";
      case '&':
        return "__and__";
      case '|':
        return "__or__";
      case '^':
        return "__xor__";
      case TK_IN:
        return "__contains__";
      case TK_LSHIFT:
        return "__lshift__";
      case TK_RSHIFT:
        return "__rshift__";
      default:
        throw std::runtime_error("unknown kind " + std::to_string(kind));
    }
  }
3317 
getNamedValuestorch::jit::to_ir3318   std::vector<NamedValue> getNamedValues(
3319       const TreeList& trees,
3320       bool maybe_unpack) {
3321     std::vector<NamedValue> values;
3322     for (const auto& tree : trees) {
3323       if (maybe_unpack && tree->kind() == TK_STARRED) {
3324         auto starred = Starred(tree);
3325         auto entries = emitSugaredExpr(starred.expr(), 1)
3326                            ->asTuple(starred.range(), method);
3327         for (const auto& entry : entries) {
3328           values.emplace_back(
3329               tree->range(), entry->asValue(starred.range(), method));
3330         }
3331       } else {
3332         values.emplace_back(tree->range(), emitExpr(Expr(tree)));
3333       }
3334     }
3335     return values;
3336   }
  // Convenience overload: unwraps a List<Expr> into its underlying TreeList.
  std::vector<NamedValue> getNamedValues(
      const List<Expr>& trees,
      bool maybe_unpack) {
    return getNamedValues(trees.tree()->trees(), maybe_unpack);
  }
3342 
  // Like getNamedValues, but drops the names and returns raw graph Values.
  std::vector<Value*> getValues(const TreeList& trees, bool maybe_unpack) {
    return toValues(*graph, getNamedValues(trees, maybe_unpack));
  }
  // Convenience overload: unwraps a List<Expr> into its underlying TreeList.
  std::vector<Value*> getValues(const List<Expr>& trees, bool maybe_unpack) {
    return getValues(trees.tree()->trees(), maybe_unpack);
  }
3349 
emitAttributestorch::jit::to_ir3350   std::vector<NamedValue> emitAttributes(const List<Attribute>& attributes) {
3351     return fmap(attributes, [&](const Attribute& attr) {
3352       return NamedValue(
3353           attr.range(), attr.name().name(), emitExpr(attr.value()));
3354     });
3355   }
3356 
checkApplyNumInputstorch::jit::to_ir3357   void checkApplyNumInputs(const Apply& apply, size_t expected_inputs) {
3358     const SourceRange& loc = apply.range();
3359     if (apply.inputs().size() != expected_inputs) {
3360       throw(
3361           ErrorReport(loc) << Var(apply.callee()).name().name()
3362                            << " expected exactly " << expected_inputs
3363                            << " arguments but found " << apply.inputs().size());
3364     }
3365     if (!apply.attributes().empty()) {
3366       throw(
3367           ErrorReport(loc) << Var(apply.callee()).name().name()
3368                            << " takes no keyword arguments");
3369     }
3370   }
3371 
checkApplyNumInputsRangetorch::jit::to_ir3372   void checkApplyNumInputsRange(
3373       const Apply& apply,
3374       size_t min_expected_inputs,
3375       size_t max_expected_inputs) {
3376     const SourceRange& loc = apply.range();
3377     size_t position_arg_size = apply.inputs().size();
3378     if (position_arg_size < min_expected_inputs ||
3379         position_arg_size > max_expected_inputs) {
3380       throw(
3381           ErrorReport(loc) << Var(apply.callee()).name().name()
3382                            << " expected to have number of arguments between "
3383                            << min_expected_inputs << " and "
3384                            << max_expected_inputs << " but found "
3385                            << position_arg_size);
3386     }
3387     if (!apply.attributes().empty()) {
3388       throw(
3389           ErrorReport(loc) << Var(apply.callee()).name().name()
3390                            << " takes no keyword arguments");
3391     }
3392   }
3393 
emitApplyExprtorch::jit::to_ir3394   std::shared_ptr<SugaredValue> emitApplyExpr(
3395       Apply& apply,
3396       size_t n_binders,
3397       const TypePtr& type_hint = nullptr) {
3398     auto sv = emitSugaredExpr(apply.callee(), 1);
3399     auto loc = apply.callee().range();
3400     if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
3401       return emitApplySpecialForm(special_form->form(), apply, sv, type_hint);
3402     }
3403     auto args = getNamedValues(apply.inputs(), true);
3404     auto kwargs = emitAttributes(apply.attributes());
3405     return sv->call(loc, method, args, kwargs, n_binders);
3406   }
3407 
  // this function handles expressions that look like apply statements
  // but have special evaluation rules for the arguments.
  // when adding a new case, only add a special form if it cannot be expressed
  // using the standard SugaredValue::call function, which enforces normal
  // evaluation order.
  std::shared_ptr<SugaredValue> emitApplySpecialForm(
      Symbol form,
      Apply& apply,
      const std::shared_ptr<SugaredValue>& sv,
      const TypePtr& type_hint = nullptr) {
    switch (form) {
      // fork(fn, *args, **kwargs): first input is the callable, the remaining
      // inputs and all keyword arguments are forwarded to it.
      case prim::fork: {
        auto& trees = apply.inputs().tree()->trees();
        if (trees.empty()) {
          throw(
              ErrorReport(apply) << "Expected at least one argument to fork()");
        }
        auto forked = emitSugaredExpr(Expr(trees[0]), 1);
        TreeList sliced_trees(trees.begin() + 1, trees.end());
        auto args = getNamedValues(sliced_trees, true);
        auto kwargs = emitAttributes(apply.attributes());
        return emitForkExpr(apply.range(), forked, args, kwargs);
      }
      // awaitable(fn, *args, **kwargs): same argument split as fork().
      case prim::awaitable: {
        auto tree = apply.inputs().tree();
        if (!tree || tree->trees().empty()) {
          throw(
              ErrorReport(apply)
              << "Expected at least one argument to awaitable()");
        }
        auto& trees = tree->trees();
        auto awaited = emitSugaredExpr(Expr(trees[0]), 1);
        TreeList sliced_trees(trees.begin() + 1, trees.end());
        auto args = getNamedValues(sliced_trees, true);
        auto kwargs = emitAttributes(apply.attributes());
        return emitAwaitableExpr(apply.range(), awaited, args, kwargs);
      }
      // annotate(type, value): parse the first argument as a type, emit the
      // second under that hint, and verify subtyping.
      case prim::annotate: {
        checkApplyNumInputs(apply, 2);
        TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
        Value* expr = tryConvertToType(
            apply.range(),
            *graph,
            type,
            emitExpr(apply.inputs()[1], type),
            /*allow_conversions=*/true);

        std::stringstream why_not;
        if (!expr->type()->isSubtypeOfExt(*type, &why_not)) {
          throw(
              ErrorReport(apply.inputs())
              << "expected an expression of type " << type->repr_str()
              << " but found " << expr->type()->repr_str() << "\n"
              << why_not.str());
        }

        // None is a subtype of Optional[T], but we want to remember what T is
        // after annotation so that variables assigned to this None will still
        // get the right type. To do this, we make a None constant that
        // has the type Optional[T]
        if ((type->kind() == OptionalType::Kind ||
             (type->kind() == UnionType::Kind &&
              type->expect<UnionType>()->canHoldType(*NoneType::get()))) &&
            expr->type()->isSubtypeOf(*NoneType::get())) {
          Node* none = graph->createNone();
          none->output()->setType(type);
          graph->insertNode(none);
          expr = none->output();
        }

        return std::make_shared<SimpleValue>(expr);
      }
      // All three RPC variants share one emission path, keyed on `form`.
      case prim::rpc_async:
      case prim::rpc_sync:
      case prim::rpc_remote: {
        return emitRpcExpr(apply, form);
      }
      // unchecked_cast(type, value): force `value` to `type` without a
      // runtime check.
      case prim::unchecked_cast: {
        checkApplyNumInputs(apply, 2);
        TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
        Value* v = emitExpr(apply.inputs()[1]);
        // avoid generating nested unchecked_casts because they are already
        // inserted during serialization
        if (v->node()->kind() != prim::unchecked_cast || *v->type() != *type) {
          v = graph->insertUncheckedCast(v, type);
        }
        return std::make_shared<SimpleValue>(v);
      } break;
      // getattr(obj, "name"[, default]): the attribute name must be a string
      // literal so it can be resolved at compile time.
      case prim::GetAttr: {
        checkApplyNumInputsRange(apply, 2, 3);
        auto obj = emitSugaredExpr(apply.inputs()[0], 1);
        auto selector = apply.inputs()[1];
        if (selector.kind() != TK_STRINGLITERAL) {
          throw(
              ErrorReport(apply)
              << "getattr's second argument must be a string literal");
        }
        const std::string& name = StringLiteral(selector).text();

        if (apply.inputs().size() == 2) {
          return obj->attr(apply.range(), method, name);
        } else {
          // 3 inputs form of getattr, the third argument is the default value
          // to return when attribute is not found
          if (obj->hasAttr(apply.range(), method, name)) {
            return obj->attr(apply.range(), method, name);
          } else {
            // attribute not found, just default val (3rd arg)
            return emitSugaredExpr(apply.inputs()[2], 1);
          }
        }
      } break;
      // Uninitialized(type): create a placeholder value of the given type.
      case prim::Uninitialized: {
        checkApplyNumInputs(apply, 1);
        TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
        auto out = graph->insertNode(graph->createUninitialized(type))
                       ->setSourceRange(apply.range());
        return std::make_shared<SimpleValue>(out->output());
      }
      // tuple(iterable): view the single argument as a tuple and materialize
      // it as a TupleConstruct node.
      case prim::TupleConstruct: {
        checkApplyNumInputs(apply, 1);
        auto arg = emitSugaredExpr(apply.inputs()[0], 1);
        auto inputs = arg->asTuple(apply.range(), method);
        auto inp_values = fmap(inputs, [&](const SugaredValuePtr& sv) {
          return sv->asValue(apply.range(), method);
        });
        return std::make_shared<SimpleValue>(
            graph->insertNode(graph->createTuple(inp_values))->output());
      }
      case prim::LegacyTypedConstructor: {
        // see legacy_tensor_generic_ctor_new
        // These legacy constructors do not follow schemas that can be
        // typed in native_functions.yaml / JIT type signature and are handled
        // here. Only the two common cases are handled initially:
        // "new(IntArrayRef size, *, Device? device=None)",
        // "new(PyObject* data, *, Device? device=None)",
        // Note: device argument is unused in the kernel
        auto args = getValues(apply.inputs(), true);
        auto kwargs = emitAttributes(apply.attributes());
        auto get_base_error_msg = [&]() {
          std::stringstream base_error_msg;
          base_error_msg
              << "Legacy Tensor Constructor only supports two schemas in TorchScript: \n";
          base_error_msg
              << "'new(IntArrayRef size, *, Device? device=None)',\n";
          base_error_msg << "'new(PyObject* data, *, Device? device=None)\n'";
          return base_error_msg;
        };
        // The only keyword argument either schema accepts is `device`.
        if (kwargs.size() == 1 && kwargs[0].name() != "device") {
          throw(
              ErrorReport(apply) << get_base_error_msg().str() << "Got kwarg "
                                 << kwargs[0].name());
        }
        if (kwargs.size() > 1) {
          throw(
              ErrorReport(apply)
              << get_base_error_msg().str() << "Got multiple kwargs\n");
        }
        auto dtype = dynamic_cast<LegacyTensorConstructor*>(sv.get())->dtype();
        auto dtype_ivalue = graph->insertConstant(dtype);

        // supporting "new(IntArrayRef size, *, Device? device=None)", through
        // empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout?
        // layout=None, Device? device=None, bool? pin_memory=None,
        // MemoryFormat? memory_format=None) -> Tensor
        bool all_ints = std::all_of(args.begin(), args.end(), [](Value* v) {
          return v->type()->cast<IntType>();
        });
        if (args.empty()) {
          // empty inputs == torch.tensor([], dtype=....)
          auto inp_list =
              graph->insertNode(graph->createList(IntType::get(), {}))
                  ->output();
          return std::make_shared<SimpleValue>(graph->insert(
              aten::tensor,
              {inp_list},
              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
        } else if (all_ints) {
          // All-int positional args are treated as a size -> aten::empty.
          auto inp_list =
              graph->insertNode(graph->createList(IntType::get(), args))
                  ->output();
          return std::make_shared<SimpleValue>(graph->insert(
              aten::empty,
              {inp_list},
              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
        } else if (args.size() == 1) {
          // Single non-int arg is treated as data -> aten::tensor.
          return std::make_shared<SimpleValue>(graph->insert(
              aten::tensor,
              {args[0]},
              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
        } else {
          throw(
              ErrorReport(apply)
              << get_base_error_msg().str()
              << "Got multiple positional arguments that were not all integers");
        }
      }
      // isinstance(value, type_or_tuple): resolved at compile time where
      // possible by emitIsInstance.
      case prim::isinstance: {
        checkApplyNumInputs(apply, 2);
        auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
        return std::make_shared<SimpleValue>(result.value());
      }
      // x.tolist(): requires a type hint (from an annotation) to know the
      // result's list type.
      case prim::tolist: {
        auto select = Select(apply.callee());
        auto value = select.value();
        auto operand = emitSugaredExpr(value, 1);

        if (!type_hint) {
          throw(
              ErrorReport(apply)
              << "Expected type hint for result of tolist()");
        }

        return std::make_shared<SimpleValue>(graph->insertToList(
            operand->asValue(value.range(), method), type_hint));
      }
      // hasattr(obj, "name"): resolved at compile time by emitHasAttr.
      case prim::HasAttr: {
        checkApplyNumInputs(apply, 2);
        const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
        return std::make_shared<SimpleValue>(result.value());
      } break;
      // This represents the "__new__" method on classes
      // because it takes a ClassValue as input.
      // So if we see:
      //   Foo.__new__(Foo)
      // Foo is a ClassValue, calling `attr("__new__")` will return a
      // CreateObject special form.
      case prim::CreateObject: {
        if (apply.inputs().size() != 1) {
          throw(ErrorReport(apply) << "Only one argument to __new__ allowed");
        }
        auto arg = emitSugaredExpr(apply.inputs()[0], 1);
        auto class_arg = dynamic_cast<ClassValue*>(arg.get());
        if (!class_arg) {
          throw(
              ErrorReport(apply)
              << "Expected class value as argument to __new__, got "
              << arg->kind() << " instead");
        }
        auto createNode =
            graph->insertNode(graph->createObject(class_arg->type_));
        createNode->setSourceRange(apply.range());
        return std::make_shared<SimpleValue>(createNode->output());
      }
      // We construct the iterable tree here using the IterableTree
      // SugaredValue, The tree consists of SimpleValue, RangeValue or
      // IterableTree: For SimpleValues(List, Dict, etc) or RangeValue. We will
      // make them as tree leaves since we could get the loop information from
      // len() and get_item(). For IterableTree like zip(), enumerate(), we can
      // model them as a combination of leaves, and we emit a IterableTree value
      // to record the tree information
      case prim::range: {
        std::vector<Value*> input_vals =
            getValues(apply.inputs(), /*maybe_unpack=*/true);
        return std::make_shared<RangeValue>(apply.range(), method, input_vals);
      }
      case prim::enumerate: {
        const SourceRange& loc = apply.range();
        auto inputs = apply.inputs();
        auto input_size = inputs.size();
        auto attributes = apply.attributes();
        auto attribute_size = attributes.size();
        // enumerate(x) can be rewrite as subtrees:
        // IterableTree(RangeValue(0, math.inf), SimpleValue(x))
        Value* start_index = nullptr;
        if (input_size == 0) {
          throw(
              ErrorReport(loc)
              << "enumerate expected at least 1 arguments, got 0");
        }

        // enumerate(x, start) — positional start index.
        if (input_size == 2) {
          start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method);
        }
        auto arg_size = input_size + attribute_size;
        if (arg_size > 2) {
          throw(
              ErrorReport(loc)
              << "enumerate expected at most 2 arguments, got " << arg_size);
        }

        // enumerate(x, start=...) — keyword start index.
        if (attribute_size == 1) {
          if (attributes[0].name().name() != "start") {
            throw(
                ErrorReport(loc)
                << "enumerate expected kwarg name 'start', got '"
                << attributes[0].name().name() << "'");
          }
          start_index =
              emitSugaredExpr(attributes[0].value(), 1)->asValue(loc, method);
        }

        // Build range([start,] INT64_MAX); the paired iterable bounds the
        // actual trip count.
        std::vector<Value*> range_inputs;
        if (start_index != nullptr) {
          range_inputs.emplace_back(start_index);
        }
        Value* end = materializeConstant(
            std::numeric_limits<int64_t>::max(),
            *graph,
            loc,
            integral_constants);
        range_inputs.emplace_back(end);
        SugaredValuePtr expr_sv = emitSugaredExpr(inputs[0], 1);
        auto iterable_value = expr_sv->iter(loc, method);

        // range should have the same static length as the other iterable
        std::optional<int64_t> iter_static_len = iterable_value->staticLen();
        SugaredValuePtr range_sv = std::make_shared<RangeValue>(
            loc, method, range_inputs, iter_static_len);

        auto tree = std::make_shared<IterableTree>();
        tree->addChild(loc, method, range_sv);
        tree->addChild(loc, method, iterable_value);
        return tree;
      }
      case prim::zip: {
        // zip(x, y) can be rewrite as subtrees:
        // IterableTree(IterableTree(x), IterableTree(y))
        auto inputs = apply.inputs();
        if (inputs.empty()) {
          throw(
              ErrorReport(apply) << "zip expected at least 1 arguments, got 0");
        }
        auto iterable_tree = std::make_shared<IterableTree>();
        for (Expr expr : inputs) {
          auto iterable = emitSugaredExpr(expr, 1)->iter(apply.range(), method);
          iterable_tree->addChild(apply.range(), method, iterable);
        }
        return iterable_tree;
      }
      case prim::list: {
        return emitApplySpecialFormForList(apply, type_hint);
      }
      case prim::dict: {
        return emitApplySpecialFormForDict(apply, type_hint);
      }
      // x[index] via an explicit __getitem__ special form; a tuple index is
      // unpacked into multiple indices.
      case aten::index: {
        const SourceRange& loc = apply.range();
        auto select = Select(apply.callee());
        auto self = emitSugaredExpr(select.value(), 1)->asValue(loc, method);

        auto inputs = apply.inputs();
        if (inputs.size() != 1) {
          throw(
              ErrorReport(apply)
              << "__getitem__ expected exactly 1 arguments, got "
              << inputs.size());
        }
        auto input =
            emitSugaredExpr(apply.inputs()[0], 1)->asValue(loc, method);
        if (input->type()->kind() == TypeKind::TupleType) {
          return std::make_shared<SimpleValue>(
              emitIndex(loc, self, createTupleUnpack(input)));
        }
        return std::make_shared<SimpleValue>(emitIndex(loc, self, {input}));
      }
      default:
        TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form);
    }
  }
3768 
emitApplySpecialFormForListtorch::jit::to_ir3769   std::shared_ptr<SugaredValue> emitApplySpecialFormForList(
3770       Apply& apply,
3771       const TypePtr& type_hint = nullptr) {
3772     if (apply.inputs().empty()) {
3773       TypePtr type = type_hint ? type_hint : ListType::ofTensors();
3774       if (!type->cast<ListType>()) {
3775         throw(
3776             ErrorReport(apply.range())
3777             << "Expected list type annotation for list(), found "
3778             << type_hint->repr_str());
3779       }
3780       return std::make_shared<SimpleValue>(
3781           graph
3782               ->insertNode(graph->createList(
3783                   type->expectRef<ListType>().getElementType(), {}))
3784               ->output());
3785     }
3786     // list(iter) desugars to [_elem for _elem in iter]
3787     checkApplyNumInputs(apply, 1);
3788     auto iter_input = emitSugaredExpr(apply.inputs()[0], 1);
3789 
3790     // aten::list builtin op is registered for List and Str input
3791     // dispatch to the builtin op to avoid perf slowdown on existing uses
3792     if (auto simple = asSimple(iter_input)) {
3793       if (simple->type()->cast<ListType>() ||
3794           simple->type()->cast<StringType>()) {
3795         return std::make_shared<SimpleValue>(emitBuiltinCall(
3796             apply.range(), *method.graph(), aten::list, {simple}, {}));
3797       }
3798     }
3799     const std::string& iter_name = createTempName("$_iter");
3800     environment_stack->setSugaredVar(
3801         apply.range(),
3802         iter_name,
3803         iter_input,
3804         /*annotated_type=*/nullptr);
3805 
3806     const std::string& elem_name = createTempName("$_elem");
3807     auto ident =
3808         Var::create(apply.range(), Ident::create(apply.range(), elem_name));
3809     auto iter =
3810         Var::create(apply.range(), Ident::create(apply.range(), iter_name));
3811     auto lc = ListComp::create(apply.range(), ident, ident, iter);
3812     return std::make_shared<SimpleValue>(emitListComprehension(lc, type_hint));
3813   }
3814 
emitApplySpecialFormForDicttorch::jit::to_ir3815   std::shared_ptr<SugaredValue> emitApplySpecialFormForDict(
3816       Apply& apply,
3817       const TypePtr& type_hint = nullptr) {
3818     auto check_type_assignment_error = [&](const TypePtr& key_type,
3819                                            const TypePtr& value_type,
3820                                            const TypePtr& annotated_dict_type) {
3821       std::stringstream ss;
3822       std::stringstream err;
3823 
3824       auto annotated_k_type =
3825           annotated_dict_type->expect<DictType>()->getKeyType();
3826       auto annotated_v_type =
3827           annotated_dict_type->expect<DictType>()->getValueType();
3828 
3829       const auto is_key_subtype = key_type == annotated_k_type;
3830       const auto is_value_subtype =
3831           value_type->isSubtypeOfExt(annotated_v_type, &ss);
3832 
3833       if (!is_key_subtype) {
3834         err << "Generated key type " << key_type->repr_str()
3835             << " did not match the annotated key type, which was "
3836             << annotated_k_type->repr_str() << "\n";
3837       }
3838 
3839       if (!is_value_subtype) {
3840         err << "Generated value type " << value_type->repr_str()
3841             << " did not match the annotated value type, which was "
3842             << annotated_v_type->repr_str() << "\n"
3843             << ss.str();
3844       }
3845 
3846       if (!is_key_subtype || !is_value_subtype) {
3847         throw(ErrorReport(apply) << err.str());
3848       }
3849     };
3850 
3851     auto add_kwargs = [&](Value* dc_value) {
3852       NamedValue self = NamedValue(apply.range(), "self", dc_value);
3853       for (const auto& kwarg : apply.attributes()) {
3854         auto name = StringLiteral::create(kwarg.range(), kwarg.name().name());
3855         auto k = emitExpr(name);
3856         auto v = emitExpr(kwarg.value());
3857         NamedValue input_k = NamedValue(kwarg.range(), "", k);
3858         NamedValue input_v = NamedValue(kwarg.range(), "", v);
3859 
3860         check_type_assignment_error(k->type(), v->type(), dc_value->type());
3861 
3862         emitBuiltinCall(
3863             kwarg.range(),
3864             *graph,
3865             aten::_set_item,
3866             {self, input_k, input_v},
3867             {});
3868       }
3869     };
3870 
3871     auto treat_as_empty_container = [&]() {
3872       // true if `dict()`
3873       if (apply.inputs().empty() && !apply.attributes().empty()) {
3874         return true;
3875       }
3876       // true if `dict({})`
3877       if (!apply.inputs().empty() &&
3878           apply.inputs()[0].kind() == TK_DICT_LITERAL) {
3879         auto dict_lit = DictLiteral(apply.inputs()[0]);
3880         return dict_lit.key_inputs().empty() && dict_lit.value_inputs().empty();
3881       }
3882       // true if `dict([])`
3883       if (!apply.inputs().empty() &&
3884           apply.inputs()[0].kind() == TK_LIST_LITERAL) {
3885         auto list_lit = ListLiteral(apply.inputs()[0]);
3886         return list_lit.inputs().empty();
3887       }
3888       return false;
3889     };
3890 
3891     TypePtr annotated_union_type =
3892         type_hint && type_hint->isUnionType() ? type_hint : nullptr;
3893 
3894     auto add_union_cast = [&](Value* result) {
3895       Node* n =
3896           graph->insertNode(graph->create(prim::unchecked_cast, {result}));
3897       n->output()->setType(std::move(annotated_union_type));
3898       result = n->output();
3899     };
3900 
3901     TypePtr refined_type_hint = type_hint;
3902 
3903     std::vector<TypePtr> all_candidates = {};
3904 
3905     auto type_match = [&](const TypePtr& t) {
3906       return t->kind() == DictType::Kind;
3907     };
3908 
3909     if (type_hint && type_hint->kind() != DictType::Kind) {
3910       refineAndSetUnionTypeHintOrPopulateCandidatesVector(
3911           type_hint,
3912           &refined_type_hint,
3913           &all_candidates,
3914           "Dict",
3915           apply,
3916           type_match,
3917           [] {},
3918           [] {},
3919           /*is_dict_constructor=*/true);
3920     }
3921 
3922     if (!all_candidates.empty()) {
3923       throw(
3924           ErrorReport(apply)
3925           << "There are multiple candidate "
3926           << "Dict types in the Union type annotation `"
3927           << type_hint->repr_str()
3928           << "`, and full type inference is not yet supported for the "
3929           << "`dict()` constructor.");
3930     }
3931 
3932     // If possible, just cast what we have to a Dict and add the
3933     // kwargs by hand. This is not only the simplest solution; it also
3934     // hits cases like `dict(dict([1, 2, 3]))` or `dict(x)` (where `x`
3935     // is some previously-defined variable)
3936     if (!apply.inputs().empty()) {
3937       // TODO(@ansley): Fix this! We have a weird situation where the
3938       // dict constructor may be handed an internal container literal
3939       // or comprehension, in which case we'd throw an error because
3940       // the lhs type wouldn't match the rhs type (the compiler wouldn't
3941       // be able to tell that this was part of a nested expression). We
3942       // used to get around this by simply not passing `type_hint`, but
3943       // 1) that's bad, and 2) we actually need `type_hint` for
3944       // inference now that Union has been introduced.
3945       std::shared_ptr<SugaredValue> iter_input;
3946       try {
3947         iter_input = emitSugaredExpr(apply.inputs()[0], 1, type_hint);
3948       } catch (const ErrorReport&) {
3949         iter_input = emitSugaredExpr(apply.inputs()[0], 1);
3950       }
3951       if (auto simple = asSimple(iter_input)) {
3952         if (simple->type()->cast<DictType>()) {
3953           auto dc_value = emitBuiltinCall(
3954               apply.range(), *method.graph(), aten::dict, {simple}, {});
3955           add_kwargs(dc_value);
3956           if (annotated_union_type) {
3957             add_union_cast(dc_value);
3958           }
3959           return std::make_shared<SimpleValue>(dc_value);
3960         }
3961       }
3962     }
3963 
3964     // If we have a call with an empty container, or if we have a
3965     // call with kwargs only
3966     if (treat_as_empty_container()) {
3967       auto expr_list = List<Expr>::create(apply.range(), {});
3968       apply = Apply::create(
3969           apply.range(), apply.callee(), expr_list, apply.attributes());
3970     }
3971 
3972     // If we have a completely empty call to dict()
3973     if (apply.inputs().empty() && apply.attributes().empty()) {
3974       if (!refined_type_hint) {
3975         refined_type_hint =
3976             DictType::create(StringType::get(), TensorType::get());
3977       } else if (!all_candidates.empty()) {
3978         throw(
3979             ErrorReport(apply.range())
3980             << "Cannot determine the type "
3981             << "of an empty dict given the Union annotation `"
3982             << type_hint->repr_str() << "`, which contains multiple "
3983             << "candidate Dict types ");
3984       }
3985 
3986       TORCH_CHECK(
3987           refined_type_hint->kind() == DictType::Kind,
3988           "Expected a type annotation "
3989           "of Dict for dict constructor dict(), got ",
3990           type_hint->str());
3991 
3992       return std::make_shared<SimpleValue>(
3993           graph
3994               ->insertNode(graph->createDict(
3995                   refined_type_hint->expect<DictType>()->getKeyType(),
3996                   refined_type_hint->expect<DictType>()->getValueType(),
3997                   {},
3998                   {}))
3999               ->output());
4000     }
4001 
4002     // Special-case logic for if we have a dict comprehension
4003     if (!apply.inputs().empty() && apply.inputs()[0].kind() == TK_DICT_COMP) {
4004       auto dc = DictComp(apply.inputs()[0]);
4005       auto dc_value = emitDictComprehension(dc, refined_type_hint);
4006       add_kwargs(dc_value);
4007       return std::make_shared<SimpleValue>(dc_value);
4008     }
4009 
4010     // We can't feasibly register all possible key x value
4011     // combinations of new prim ops for the case that we use the
4012     // constructor with a dict literal. It makes much more sense
4013     // to transform the dict literal into a list of tuples so that
4014     // we can use the existing constructors
4015     if (!apply.inputs().empty() &&
4016         apply.inputs()[0].kind() == TK_DICT_LITERAL) {
4017       auto dict_lit = DictLiteral(apply.inputs()[0]);
4018       std::vector<Expr> zipped;
4019       zipped.reserve(dict_lit.key_inputs().size());
4020       TORCH_INTERNAL_ASSERT(
4021           dict_lit.key_inputs().size() == dict_lit.value_inputs().size());
4022       for (auto key_it = dict_lit.key_inputs().begin(),
4023                 val_it = dict_lit.value_inputs().begin();
4024            key_it != dict_lit.key_inputs().end();
4025            ++key_it, ++val_it) {
4026         auto tuple_inputs =
4027             List<Expr>::create(apply.range(), {*key_it, *val_it});
4028         auto tuple = TupleLiteral::create(apply.range(), tuple_inputs);
4029         zipped.push_back(tuple);
4030       }
4031       auto ll_values = List<Expr>::create(apply.range(), zipped);
4032       auto ll = ListLiteral::create(apply.range(), ll_values);
4033       auto expr_list = List<Expr>::create(apply.range(), {ll});
4034       // Change `apply` to a new Apply node holding a list of
4035       // tuples
4036       apply = Apply::create(
4037           apply.range(), apply.callee(), expr_list, apply.attributes());
4038     }
4039 
4040     // If we have kwargs to include, we'll take a similar approach
4041     // to the above logic and standardize the Apply node
4042     if (!apply.attributes().empty() &&
4043         (apply.inputs().empty() ||
4044          apply.inputs()[0].kind() == TK_LIST_LITERAL)) {
4045       std::vector<Expr> exprs;
4046       // Gather all the existing tuples in the input iterable
4047       if (!apply.inputs().empty()) {
4048         auto tuple_list = ListLiteral(apply.inputs()[0]).inputs();
4049         for (const auto& tuple : tuple_list) {
4050           exprs.push_back(tuple);
4051         }
4052       }
4053       // Create tuples out of each kwarg and gather them as well
4054       for (const auto& attr : apply.attributes()) {
4055         auto k = StringLiteral::create(apply.range(), attr.name().name());
4056         auto v = attr.value();
4057         auto tuple_inputs = List<Expr>::create(apply.range(), {k, v});
4058         auto tuple = TupleLiteral::create(apply.range(), tuple_inputs);
4059         exprs.push_back(tuple);
4060       }
4061       auto expr_list = List<Expr>::create(apply.range(), {exprs});
4062       auto ll = ListLiteral::create(apply.range(), expr_list);
4063       auto new_inputs = List<Expr>::create(apply.range(), {ll});
4064       auto new_kwargs = List<Attribute>::create(apply.range(), {});
4065       apply =
4066           Apply::create(apply.range(), apply.callee(), new_inputs, new_kwargs);
4067     }
4068 
4069     checkApplyNumInputs(apply, 1);
4070 
4071     auto iter_input = emitSugaredExpr(apply.inputs()[0], 1);
4072 
4073     const std::string& iter_name = createTempName("$_iter");
4074     const std::string& key_name = createTempName("$_key");
4075     const std::string& value_name = createTempName("$_value");
4076 
4077     auto key =
4078         Var::create(apply.range(), Ident::create(apply.range(), key_name));
4079     auto value =
4080         Var::create(apply.range(), Ident::create(apply.range(), value_name));
4081     auto target = TupleLiteral::create(
4082         apply.range(), List<Expr>::create(apply.range(), {key, value}));
4083     auto iter =
4084         Var::create(apply.range(), Ident::create(apply.range(), iter_name));
4085 
4086     environment_stack->setSugaredVar(
4087         apply.range(),
4088         iter_name,
4089         iter_input,
4090         /*annotated_type=*/nullptr);
4091 
4092     auto dc = DictComp::create(apply.range(), key, value, target, iter);
4093     auto result = emitDictComprehension(dc, refined_type_hint);
4094     add_kwargs(result);
4095 
4096     if (annotated_union_type) {
4097       add_union_cast(result);
4098     }
4099 
4100     return std::make_shared<SimpleValue>(result);
4101   }
4102 
emitExprtorch::jit::to_ir4103   Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
4104     // Push the source range of a call in case compiling this function
4105     // triggers an error
4106     ErrorReport::CallStack::update_pending_range(tree.range());
4107     Value* out_val =
4108         emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
4109     // AnyType is the only user-exposed type which we don't unify to from
4110     // its subtypes, so we add a cast for use cases like
4111     // x : Any = 1 if cond else "str"
4112     if (type_hint == AnyType::get() && out_val->type() != AnyType::get()) {
4113       out_val = graph->insertUncheckedCast(out_val, type_hint);
4114     }
4115     return out_val;
4116   }
4117 
reverseComparisiontorch::jit::to_ir4118   NodeKind reverseComparision(NodeKind kind) {
4119     if (kind == aten::lt) {
4120       return aten::gt;
4121     } else if (kind == aten::le) {
4122       return aten::ge;
4123     } else if (kind == aten::gt) {
4124       return aten::lt;
4125     } else if (kind == aten::ge) {
4126       return aten::le;
4127     }
4128     throw std::runtime_error(
4129         "reverseComparision: unsupported NodeKind. File a bug");
4130   }
4131 
4132   // any expression that can produce a SugaredValue is handled here
4133   // expressions that only return a single Value* are handled in emitSimpleExpr
4134   // type_hint is set if there is a type that this value is expected to be
4135   // e.g. a : List[int] = []
4136   // or a = torch.jit.annotate(List[int], [])
4137   // the caller is responsible for checking that the result matches type_hint
4138   // emitSugaredExpr is free to ignore it.
emitSugaredExprtorch::jit::to_ir4139   std::shared_ptr<SugaredValue> emitSugaredExpr(
4140       const Expr& tree,
4141       size_t n_binders,
4142       const TypePtr& type_hint = nullptr) {
4143     switch (tree.kind()) {
4144       case TK_VAR: {
4145         return environment_stack->getSugaredVar(Var(tree).name());
4146       }
4147       case '.': {
4148         auto select = Select(tree);
4149         auto sv = emitSugaredExpr(select.value(), 1);
4150         return sv->attr(select.range(), method, select.selector().name());
4151       }
4152       case TK_APPLY: {
4153         auto apply = Apply(tree);
4154         return emitApplyExpr(apply, n_binders, type_hint);
4155       } break;
4156       case TK_SUBSCRIPT: {
4157         return emitSubscript(Subscript(tree), type_hint);
4158       } break;
4159       default:
4160         return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
4161     }
4162   }
4163 
emitUnaryOptorch::jit::to_ir4164   Value* emitUnaryOp(
4165       const TreeRef& tree,
4166       const std::string& magicMethod,
4167       const c10::Symbol& opSymbol) {
4168     const auto& inputs = tree->trees();
4169     auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
4170     auto val =
4171         asSimple(makeMagic(
4172                      magicMethod,
4173                      std::make_shared<BuiltinFunction>(opSymbol, std::nullopt))
4174                      ->call(tree->range(), method, named_values, {}, 0));
4175 
4176     // if we emitted the unary op and not some other overloaded function,
4177     // then try to constantfold
4178     if (val->node()->kind() != opSymbol) {
4179       return val;
4180     }
4181 
4182     auto maybe_out_stack = runNodeIfInputsAreConstant(val->node());
4183     if (!maybe_out_stack) {
4184       return val;
4185     }
4186     TORCH_INTERNAL_ASSERT(maybe_out_stack->size() == 1);
4187     return graph->insertConstant(maybe_out_stack->at(0), tree->range());
4188   }
4189 
4190   /**
4191    * Emit a fork expression, of the form:
4192    *   torch.jit.fork(forked, *args, **kwargs)
4193    */
emitForkExprtorch::jit::to_ir4194   std::shared_ptr<SugaredValue> emitForkExpr(
4195       SourceRange loc,
4196       const std::shared_ptr<SugaredValue>& forked,
4197       at::ArrayRef<NamedValue> args,
4198       at::ArrayRef<NamedValue> kwargs) {
4199     auto g = method.graph();
4200     TypePtr out_type;
4201 
4202     auto fork_node = g->insertNode(method.graph()->create(prim::forkClosure, 1))
4203                          ->setSourceRange(loc);
4204 
4205     // We create a fork by emitting a closure and setting the closure output
4206     // into the fork input. If a closure doesn't already exist, we create one.
4207     {
4208       WithInsertPoint insert(fork_node);
4209       if (ClosureValue* sv = dynamic_cast<ClosureValue*>(forked.get())) {
4210         Value* closure_output = sv->asValue(loc, method);
4211         Block* closure_block = closure_output->node()->blocks().at(0);
4212         TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1);
4213         out_type = closure_block->outputs().at(0)->type();
4214         fork_node->addInput(closure_output);
4215       } else {
4216         auto emit_closure_body = [&](Block* closure_block) {
4217           auto fn_sugared_output = forked->call(loc, method, args, kwargs, 1);
4218           auto fn_simple_output = fn_sugared_output->asValue(loc, method);
4219           closure_block->registerOutput(fn_simple_output);
4220           out_type = fn_simple_output->type();
4221         };
4222         auto closure_value = emitClosure(emit_closure_body);
4223         fork_node->addInput(closure_value->asValue(loc, method));
4224       }
4225     }
4226     Value* node_output =
4227         fork_node->output()->setType(FutureType::create(out_type));
4228     return std::make_shared<SimpleValue>(node_output);
4229   }
4230 
emitAwaitableExprtorch::jit::to_ir4231   std::shared_ptr<SugaredValue> emitAwaitableExpr(
4232       SourceRange loc,
4233       const std::shared_ptr<SugaredValue>& awaited,
4234       at::ArrayRef<NamedValue> args,
4235       at::ArrayRef<NamedValue> kwargs) {
4236     auto g = method.graph();
4237     TypePtr out_type{};
4238 
4239     auto await_node =
4240         g->insertNode(method.graph()->create(prim::awaitableClosure, 1))
4241             ->setSourceRange(loc);
4242 
4243     {
4244       WithInsertPoint insert(await_node);
4245       if (auto sv = dynamic_cast<ClosureValue*>(awaited.get())) {
4246         Value* closure_output = sv->asValue(loc, method);
4247         Block* closure_block = closure_output->node()->blocks().at(0);
4248         TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1);
4249         out_type = closure_block->outputs().at(0)->type();
4250         await_node->addInput(closure_output);
4251       } else {
4252         auto emit_closure_body = [&](Block* closure_block) {
4253           auto fn_sugared_output = awaited->call(loc, method, args, kwargs, 1);
4254           auto fn_simple_output = fn_sugared_output->asValue(loc, method);
4255           closure_block->registerOutput(fn_simple_output);
4256           out_type = fn_simple_output->type();
4257         };
4258         auto closure_value = emitClosure(emit_closure_body);
4259         await_node->addInput(closure_value->asValue(loc, method));
4260       }
4261     }
4262     Value* node_output =
4263         await_node->output()->setType(AwaitType::create(out_type));
4264     return std::make_shared<SimpleValue>(node_output);
4265   }
4266 
emitRpcExprtorch::jit::to_ir4267   std::shared_ptr<SugaredValue> emitRpcExpr(const Apply& apply, Symbol rpc_op) {
4268     // TODO: This is a temporary apporoach to enable calling user fucntion
4269     // through RPC in TorchScript,
4270     // Ideally, function value in JIT IR is first-class citizen and
4271     // The RPC C++ entry API can take c10::Function directly.
4272     size_t rpcMinInputs = 2;
4273     size_t rpcMaxInputs = 5;
4274     std::string op_name = rpc_op.toUnqualString();
4275     if (apply.inputs().size() < rpcMinInputs ||
4276         apply.inputs().size() > rpcMaxInputs) {
4277       throw(
4278           ErrorReport(apply)
4279           << "Possible forms of call to " << op_name << "(..) are\n"
4280           << op_name
4281           << "(dst_worker_name, user_callable, args, kwargs, timeout)\n"
4282           << op_name << "(dst_worker_name, user_callable, args, kwargs)\n"
4283           << op_name << "(dst_worker_name, user_callable, args)\n"
4284           << op_name << "(dst_worker_name, user_callable)\n"
4285           << "Now the number of arguments is " << apply.inputs().size());
4286     }
4287     if (!apply.attributes().empty()) {
4288       throw(
4289           ErrorReport(apply)
4290           << op_name << "(dst_worker_name, user_callable, args, kwargs)"
4291           << "does not support kwargs yet");
4292     }
4293     // TODO: Make rpc_op(..) support taking kwargs,
4294     // like rpc_async(to="worker1", func=my_func, args=(), kwargs={})
4295 
4296     auto& input_trees = apply.inputs().tree()->trees();
4297     Value* dst_worker_name_value = emitExpr(Expr(input_trees[0]));
4298     std::shared_ptr<SugaredValue> user_callable_sugared_value =
4299         emitSugaredExpr(Expr(input_trees[1]), 1);
4300     TORCH_CHECK(
4301         user_callable_sugared_value->kind() == "function",
4302         "user_callable should be a FunctionValue, it's now a ",
4303         user_callable_sugared_value->kind())
4304     // NB: This should be done using `std::dynamic_pointer_cast`
4305     // and assert `user_callable_function_value != nullptr`. But somehow on
4306     // macos std::dynamic_pointer_cast always returns
4307     // `user_callable_function_value` as a `nullptr`, even if
4308     // `user_callable_sugared_value->kind() == "function"`.
4309     std::shared_ptr<FunctionValue> user_callable_function_value =
4310         std::static_pointer_cast<FunctionValue>(user_callable_sugared_value);
4311     // If `kwargs` is an empty dict, users are allowed to not pass `kwargs`.
4312     // If `args` and `kwargs` are an empty tuple and an empty dict,
4313     // respectively, users are allowed to not pass `args` and `kwargs`.
4314 
4315     TreeList args_kwargs_timeout_trees(
4316         input_trees.begin() + 2, input_trees.end());
4317 
4318     // Get user callable.
4319     const auto& callablePtrs = user_callable_function_value->callees();
4320     TORCH_INTERNAL_ASSERT(
4321         callablePtrs.size() == 1,
4322         "User-provided callable size should be 1. Now it's",
4323         callablePtrs.size())
4324     Function* callablePtr = callablePtrs.at(0);
4325 
4326     const auto& functionSchema = callablePtr->getSchema();
4327     const SourceRange& loc = apply.range();
4328     auto graphPtr = method.graph();
4329 
4330     // Match FunctionSchema.
4331     std::vector<NamedValue> args;
4332     std::vector<NamedValue> kwargs;
4333     // Get args and kwargs as `NamedValue`s.
4334     // Similar to getNamedValues(..) and emitAttributes(..).
4335     if (!args_kwargs_timeout_trees.empty()) {
4336       // Unroll args from a Var that is known to be a Tuple.
4337       auto& args_tree = args_kwargs_timeout_trees[0];
4338       auto entry_sugared_values = emitSugaredExpr(Expr(args_tree), 1)
4339                                       ->asTuple(args_tree->range(), method);
4340       args.reserve(entry_sugared_values.size());
4341       for (const auto& entrie_sugared_value : entry_sugared_values) {
4342         args.emplace_back(
4343             args_tree->range(),
4344             entrie_sugared_value->asValue(args_tree->range(), method));
4345       }
4346       // NB: Can't do schema check on kwargs, given the RPC API is
4347       // rpc_op(to, user_callable, args, kwargs),
4348       // users can construct kwargs = {"first" + "_arg" : 1}.
4349       // Notice the key is determined at run time.
4350       // We can do it at compile time, unless one day the RPC API is
4351       // rpc_op(to, user_callable, arg_0, arg_1, kwarg_0="foo",
4352       // kwarg_1="bar")
4353     }
4354     matchSchema(functionSchema, loc, *graphPtr, args, kwargs);
4355 
4356     // Graph insert the QualifiedName as an constant input IR Value.
4357     const auto& qualname = callablePtr->qualname();
4358     IValue userCallableQualNameIValue(qualname.qualifiedName());
4359     Value* userCallableQualNameValue =
4360         graphPtr->insertConstant(userCallableQualNameIValue, loc);
4361 
4362     // Graph insert the corresponding RPC node to the graph.
4363     Node* rpc_node =
4364         graphPtr->insertNode(graphPtr->create(rpc_op, 1))->setSourceRange(loc);
4365     {
4366       WithInsertPoint insert(rpc_node);
4367       rpc_node->addInput(dst_worker_name_value);
4368       rpc_node->addInput(userCallableQualNameValue);
4369 
4370       for (const auto& tree : args_kwargs_timeout_trees) {
4371         rpc_node->addInput(emitExpr(Expr(tree)));
4372       }
4373     }
4374     Value* rpc_node_output = rpc_node->output();
4375 
4376     // Set output type from FunctionSchema and corresponding rpc_op.
4377     const std::vector<Argument>& returns = functionSchema.returns();
4378     TORCH_INTERNAL_ASSERT(returns.size() == 1);
4379     TypePtr output_type = nullptr;
4380     if (rpc_op == prim::rpc_async) {
4381       // rpc_async returns FutureType of the functionSchema's return type
4382       output_type = FutureType::create(returns[0].type());
4383     } else if (rpc_op == prim::rpc_sync) {
4384       // rpc_sync returns the functionSchema's return type
4385       output_type = returns[0].type();
4386     } else if (rpc_op == prim::rpc_remote) {
4387       // rpc_remote returns RRefType of the functionSchema's return type
4388       output_type = RRefType::create(returns[0].type());
4389     } else {
4390       throw(
4391           ErrorReport(apply)
4392           << rpc_op.toDisplayString() << " is not supported in TorchScript!'");
4393     }
4394     rpc_node_output->setType(output_type);
4395     return std::make_shared<SimpleValue>(rpc_node_output);
4396   }
4397 
emitBinaryOptorch::jit::to_ir4398   Value* emitBinaryOp(const TreeRef& tree) {
4399     const auto& inputs = tree->trees();
4400     auto kind = getNodeKind(tree->kind(), inputs.size());
4401     auto overload = getOperatorOverload(tree->kind(), inputs.size());
4402     auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
4403     if (tree->kind() == TK_IN) {
4404       // For `in` the arguments are in reverse order (the object being
4405       // checked is second)
4406       std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
4407     }
4408 
4409     // if this is adding two tuples, we deal with it here.
4410     // the reason is we can't specify the length of tuples
4411     // when registering custom aten::add.
4412     if (named_values[0].type()->kind() == TupleType::Kind &&
4413         named_values[1].type()->kind() == TupleType::Kind &&
4414         kind == aten::add) {
4415       auto first_tuple = createTupleUnpack(named_values[0].value(*graph)).vec();
4416       auto second_tuple =
4417           createTupleUnpack(named_values[1].value(*graph)).vec();
4418       first_tuple.insert(
4419           first_tuple.end(), second_tuple.begin(), second_tuple.end());
4420       return graph->insertNode(graph->createTuple(first_tuple))->output();
4421     }
4422 
4423     return asSimple(
4424         makeMagic(
4425             overload, std::make_shared<BuiltinFunction>(kind, std::nullopt))
4426             ->call(tree->range(), method, named_values, {}, 0));
4427   }
4428 
  // Emits a list literal `[a, b, ...]` (or `[]`) as a ListConstruct
  // node. `type_hint` — possibly a Union/Optional wrapping one or more
  // List types — is used to refine the element type; with no hint, the
  // element type is the unified type of the elements (List[Tensor] for
  // an empty literal).
  Value* emitListLiteral(const ListLiteral& ll, const TypePtr& type_hint) {
    auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);

    // Empty List Literals that are not assigned to variables
    // may match to any list type in schema matching,
    // but still default to List[Tensor] if assigned to a variable
    // or returned from a function
    // Restricting empty list matching to temporary values
    // avoids difficult to handle cases such as
    // a = []
    // b = a
    // if cond:
    //    b.append(2)
    // else:
    //    a.append("hi")
    // This is also the same behavior that C++ allows with {}
    // (cannot assign to a variable typed as auto)
    // These nodes will be removed in a later pass after initial compilation
    if (values.empty() && type_hint == nullptr) {
      auto node = graph->insertNode(graph->create(prim::EmptyListLiteral));
      node->output()->setType(ListType::ofTensors());
      return node->output();
    }

    // Determine the element type of the list. If we have a type hint
    // of `List[T]`, use `T`. If the list is non-empty, find the
    // greatest common supertype of all the list elements (defaulting to
    // `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]`
    TypePtr inferred_elem_type = TensorType::get();

    TypePtr refined_type_hint = type_hint;

    // If `type_hint` is a Union/Optional, we're going to change it to
    // be the type of the rhs List, so we need to store the original
    // UnionType for later. `nullptr` means that we don't need to emit
    // an `unchecked_cast` node (either because we don't have a type
    // hint or because the type hint wasn't a Union)
    TypePtr annotated_union_type =
        refined_type_hint && refined_type_hint->isUnionType()
        ? refined_type_hint
        : nullptr;

    // This is used in the case that we have a Union annotation that
    // contains multiple Lists
    std::vector<TypePtr> all_candidates = {};

    if (refined_type_hint) {
      // Invoked when the hint (or exactly one member of a Union hint)
      // is a List: adopt its element type.
      auto do_if_type_match = [&]() {
        auto list_type_hint = refined_type_hint->cast<ListType>();
        inferred_elem_type = list_type_hint->getElementType();
      };

      auto type_match = [&](const TypePtr& t) {
        return t->isSubtypeOf(AnyListType::get());
      };

      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "List",
          ll,
          type_match,
          do_if_type_match,
          do_if_type_match);

      // An empty literal cannot disambiguate between several candidate
      // List types from a Union annotation.
      if (!all_candidates.empty() && values.empty()) {
        throw(
            ErrorReport(ll)
            << "Cannot assign an empty list to a "
            << "variable annotated to be type " << refined_type_hint->repr_str()
            << " because there are multiple possible List "
            << "type candidates in the Union annotation");
      }
    }

    if (!values.empty()) {
      auto types = fmap(values, [](const Value* v) { return v->type(); });

      std::stringstream nowhere; // never used

      // We don't want to use `elem_type` as the final argument to
      // `unifyTypeList` because there's a chance that `elem_type` is
      // the Tensor default
      const auto elem_type_hint =
          refined_type_hint && refined_type_hint->kind() == ListType::Kind
          ? refined_type_hint->cast<ListType>()->getElementType()
          : nullptr;

      std::optional<TypePtr> unified_elem_type = unifyTypeList(
          types, nowhere, /*default_to_union=*/true, elem_type_hint);

      if (!refined_type_hint &&
          (*unified_elem_type)->kind() == UnionType::Kind) {
        TORCH_WARN(
            "List consists of heterogeneous types, which means",
            " that it has been typed as containing ",
            (*unified_elem_type)->repr_str(),
            ". To use any of the "
            "values in this List, it will be necessary to add an "
            "`assert isinstance` statement before first use to trigger "
            "type refinement.\n",
            ll.range().str());
      }

      if (all_candidates.empty() && refined_type_hint &&
          !(*unified_elem_type)->isSubtypeOf(*inferred_elem_type)) {
        throw(
            ErrorReport(ll)
            << "List type annotation `" << refined_type_hint->repr_str()
            << "` did not match the types of the given list elements,"
            << " which were unified to " << (*unified_elem_type)->repr_str());
      }

      if (!all_candidates.empty()) {
        // Multiple List candidates in the Union: pick the one that
        // matches the unified element type.
        refineAndSetListTypeHintFromCandidatesVector(
            all_candidates,
            type_hint,
            &refined_type_hint,
            *unified_elem_type,
            ll);
        inferred_elem_type =
            refined_type_hint->expect<ListType>()->getElementType();
      }

      // We only want to set `elem_type` if we don't have a type hint
      // to allow for the case that `*unified` is a subtype of
      // `type_hint`
      if (!refined_type_hint) {
        inferred_elem_type = *unified_elem_type;
      }
    }

    Node* result =
        graph->insertNode(graph->createList(inferred_elem_type, values));
    if (annotated_union_type) {
      // The annotation was a Union: cast the freshly-typed List back up
      // to the annotated Union type.
      Node* n = graph->insertNode(
          graph->create(prim::unchecked_cast, {result->output()}));
      n->output()->setType(std::move(annotated_union_type));
      result = n;
    }

    return result->output();
  }
4573 
emitDictLiteraltorch::jit::to_ir4574   Value* emitDictLiteral(DictLiteral dl, const TypePtr& type_hint) {
4575     auto key_trees = dl.key_inputs().tree()->trees();
4576     auto value_trees = dl.value_inputs().tree()->trees();
4577 
4578     AT_ASSERT(key_trees.size() == value_trees.size());
4579 
4580     std::vector<Value*> keys, values;
4581     TypePtr rhs_value_type;
4582 
4583     for (const auto i : c10::irange(key_trees.size())) {
4584       keys.push_back(emitExpr(Expr(key_trees[i])));
4585       values.push_back(emitExpr(Expr(value_trees[i])));
4586 
4587       if (i == 0) {
4588         rhs_value_type = values[i]->type();
4589       } else {
4590         if (keys[i - 1]->type()->kind() != keys[i]->type()->kind()) {
4591           throw(
4592               ErrorReport(key_trees[i])
4593               << "Dict keys must contain "
4594               << "only a single type. Expected: "
4595               << keys[i - 1]->type()->repr_str() << " but found "
4596               << keys[i]->type()->repr_str() << " instead");
4597         }
4598         rhs_value_type = *(unifyTypes(
4599             rhs_value_type, values[i]->type(), /*default_to_union=*/true));
4600       }
4601     }
4602 
4603     TypePtr refined_type_hint = type_hint;
4604 
4605     TypePtr annotated_union_type =
4606         type_hint && type_hint->isUnionType() ? type_hint : nullptr;
4607 
4608     std::vector<TypePtr> all_candidates = {};
4609 
4610     auto default_refined_type_hint_setter = [&]() {
4611       if (keys.empty()) {
4612         refined_type_hint =
4613             DictType::create(StringType::get(), TensorType::get());
4614       } else {
4615         refined_type_hint =
4616             DictType::create(keys.at(0)->type(), rhs_value_type);
4617         if (rhs_value_type->kind() == UnionType::Kind) {
4618           TORCH_WARN(
4619               "Dict values consist of heterogeneous types, which means",
4620               " that the dict has been typed as containing ",
4621               refined_type_hint->repr_str(),
4622               ". To use any of the values in this Dict, it will be "
4623               "necessary to add an `assert isinstance` statement before "
4624               "first use to trigger type refinement.\n",
4625               dl.range().str());
4626         }
4627       }
4628     };
4629 
4630     if (type_hint) {
4631       auto type_match = [&](const TypePtr& t) {
4632         return t->kind() == DictType::Kind;
4633       };
4634 
4635       refineAndSetUnionTypeHintOrPopulateCandidatesVector(
4636           type_hint,
4637           &refined_type_hint,
4638           &all_candidates,
4639           "Dict",
4640           dl,
4641           type_match,
4642           [] {},
4643           default_refined_type_hint_setter);
4644 
4645       if (!all_candidates.empty() && values.empty()) {
4646         throw(
4647             ErrorReport(dl)
4648             << "Cannot assign an empty dict to a "
4649             << "variable annotated to be type " << type_hint->repr_str()
4650             << " because there are multiple possible Dict "
4651             << "type candidates in the Union annotation");
4652       }
4653     } else {
4654       default_refined_type_hint_setter();
4655     }
4656 
4657     // We must have either a) specific key/value types already, or b) a
4658     // list of possible candidates
4659     TORCH_INTERNAL_ASSERT(!all_candidates.empty() || refined_type_hint);
4660 
4661     if (!values.empty()) {
4662       if (!all_candidates.empty()) {
4663         refineAndSetDictTypeHintFromCandidatesVector(
4664             all_candidates,
4665             type_hint,
4666             &refined_type_hint,
4667             keys[0]->type(),
4668             rhs_value_type,
4669             dl);
4670       }
4671 
4672       if (refined_type_hint->expect<DictType>()->getKeyType() !=
4673           keys.at(0)->type()) {
4674         throw(
4675             ErrorReport(dl)
4676             << "Type annotation was inferred to be "
4677             << refined_type_hint->repr_str()
4678             << "but the type of keys given by the dict literal is "
4679             << keys.at(0)->type()->repr_str());
4680       }
4681 
4682       if (!rhs_value_type->isSubtypeOf(
4683               refined_type_hint->expect<DictType>()->getValueType())) {
4684         throw(
4685             ErrorReport(dl)
4686             << "Type annotation was inferred to be `"
4687             << refined_type_hint->repr_str()
4688             << "`, but the type of values given by the dict literal is "
4689             << rhs_value_type->repr_str());
4690       }
4691     }
4692 
4693     Node* result = graph->insertNode(graph->createDict(
4694         refined_type_hint->expect<DictType>()->getKeyType(),
4695         refined_type_hint->expect<DictType>()->getValueType(),
4696         keys,
4697         values));
4698     if (annotated_union_type) {
4699       Node* n = graph->insertNode(
4700           graph->create(prim::unchecked_cast, {result->output()}));
4701       n->output()->setType(std::move(annotated_union_type));
4702       result = n;
4703     }
4704 
4705     return result->output();
4706   }
4707 
  // Emits a single (non-sugared) expression tree as a Value, dispatching
  // on the tree kind. `type_hint` is forwarded to the literal and
  // comprehension emitters that can make use of it.
  Value* emitSimpleExpr(
      const TreeRef& tree,
      const TypePtr& type_hint = nullptr) {
    switch (tree->kind()) {
      case TK_FLOOR_DIV:
      case '@': {
        const auto& inputs = tree->trees();
        auto kind = getNodeKind(tree->kind(), inputs.size());
        auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
        return emitBuiltinCall(
            tree->range(), *method.graph(), kind, named_values, {});
      }
      case '%': {
        // `%` is overloaded: string formatting (aten::percentFormat)
        // when the lhs is a str, otherwise ordinary binary modulo.
        auto lhs = emitSugaredExpr(Expr(tree->tree(0)), 0)
                       ->asValue(tree->tree(0)->range(), method);
        auto const& lhs_type = lhs->type();
        if (lhs_type == StringType::get()) {
          auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
          auto node = graph->create(aten::percentFormat, values, 1)
                          ->setSourceRange(tree->range());
          Value* output = graph->insertNode(node)->output();
          output->setType(StringType::get());
          return output;
        } else {
          return emitBinaryOp(tree);
        }
      }
      // Ordinary binary operators.
      case TK_IN:
      case TK_POW:
      case TK_NE:
      case TK_EQ:
      case '<':
      case '>':
      case TK_LE:
      case TK_GE:
      case '*':
      case '/':
      case '+':
      case '-':
      case '&':
      case '|':
      case '^':
      case TK_LSHIFT:
      case TK_RSHIFT:
        return emitBinaryOp(tree);
      // Boolean / identity operators go through conditional-expression
      // emission so short-circuiting and refinement work.
      case TK_IS:
      case TK_ISNOT:
      case TK_AND:
      case TK_OR:
      case TK_NOT: {
        return emitCondExpr(Expr(tree)).value();
      }
      case TK_UNARY_MINUS: {
        return emitUnaryOp(tree, "__neg__", aten::neg);
      }
      case '~': {
        return emitUnaryOp(tree, "__invert__", aten::bitwise_not);
      }
      case TK_STARRED: {
        // Starred expansions are consumed by callers (e.g. call-site
        // argument handling); reaching one here is a frontend bug.
        throw(
            ErrorReport(tree)
            << "Unexpected starred expansion. File a bug report");
      }
      case TK_CONST: {
        return emitConst(Const(tree));
      } break;
      case TK_TRUE: {
        return graph->insertConstant(true, tree->range());
      } break;
      case TK_FALSE: {
        return graph->insertConstant(false, tree->range());
      } break;
      case TK_NONE: {
        return graph->insertConstant(IValue(), tree->range());
      } break;
      case TK_IF_EXPR: {
        return emitTernaryIf(TernaryIf(tree), type_hint);
      } break;
      case TK_STRINGLITERAL: {
        return emitStringLiteral(StringLiteral(tree));
      } break;
      case TK_LIST_LITERAL: {
        auto ll = ListLiteral(tree);
        return emitListLiteral(ll, type_hint);
      } break;
      case TK_TUPLE_LITERAL: {
        auto ll = TupleLiteral(tree);
        auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
        return graph->insertNode(graph->createTuple(values))->output();
      } break;
      case TK_DICT_LITERAL: {
        auto dc = DictLiteral(tree);
        return emitDictLiteral(dc, type_hint);
      } break;
      case TK_LIST_COMP: {
        auto lc = ListComp(tree);
        return emitListComprehension(lc, type_hint);
      } break;
      case TK_DICT_COMP: {
        auto dc = DictComp(tree);
        return emitDictComprehension(dc, type_hint);
      } break;
      default:
        throw(ErrorReport(tree) << "Cannot emit expr for: " << tree);
    }
  }
4814 
emitConsttorch::jit::to_ir4815   Value* emitConst(const Const& c) {
4816     if (c.isFloatingPoint())
4817       return materializeConstant(
4818           c.asFloatingPoint(), *graph, c.range(), fp_constants);
4819     else if (c.isComplex())
4820       return materializeConstant(
4821           c.asComplex(), *graph, c.range(), complex_constants);
4822     else
4823       return materializeConstant(
4824           c.asIntegral(), *graph, c.range(), integral_constants);
4825   }
4826 
emitStringLiteraltorch::jit::to_ir4827   Value* emitStringLiteral(const StringLiteral& c) {
4828     return insertConstant(*graph, c.text(), c.range());
4829   }
4830 
4831   // Desugars select indexing: tensor[i] -> tensor.select(dim, i)
emitSelecttorch::jit::to_ir4832   Value* emitSelect(
4833       const SourceRange& loc,
4834       Value* input,
4835       Value* dim,
4836       Value* index) {
4837     return emitBuiltinCall(loc, *graph, aten::select, {input, dim, index}, {});
4838   }
4839 
emitSliceOptorch::jit::to_ir4840   Value* emitSliceOp(
4841       const SourceRange& loc,
4842       Value* sliceable,
4843       Value* dim,
4844       Value* start,
4845       Value* end,
4846       Value* step) {
4847     std::vector<NamedValue> args;
4848     args.reserve(5);
4849     args.emplace_back(loc, "self", sliceable);
4850 
4851     // XXX: If list slicing becomes more complicated or stops using
4852     // aten::slice, we should separate it from this function.
4853     if (dim) {
4854       AT_ASSERT(sliceable->type()->isSubtypeOf(*TensorType::get()));
4855 
4856       args.emplace_back(dim);
4857     } else {
4858       AT_ASSERT(!sliceable->type()->isSubtypeOf(*TensorType::get()));
4859     }
4860 
4861     if (sliceable->type()->cast<TupleType>()) {
4862       std::vector<std::optional<NamedValue>> tuple_args;
4863       // since we are only dealing with tuple slicing, we try to keep
4864       // tuple args separate for now
4865       tuple_args.reserve(3);
4866 
4867       start ? tuple_args.emplace_back(start)
4868             : tuple_args.emplace_back(std::nullopt);
4869       end ? tuple_args.emplace_back(end)
4870           : tuple_args.emplace_back(std::nullopt);
4871       step ? tuple_args.emplace_back(step)
4872            : tuple_args.emplace_back(std::nullopt);
4873 
4874       return emitTupleSlice(loc, args[0], tuple_args);
4875     }
4876 
4877     // handling cases like x[0:2]. x[0:2:] is already handled from python
4878     if (!step) {
4879       step = graph->insertConstant(1, loc);
4880     }
4881 
4882     args.emplace_back(loc, "start", start);
4883     args.emplace_back(loc, "end", end);
4884     args.emplace_back(loc, "step", step);
4885     return emitBuiltinCall(loc, *graph, aten::slice, args, {});
4886   }
4887 
4888   // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
4889   // 1)
emitSlicetorch::jit::to_ir4890   Value* emitSlice(
4891       const SourceRange& loc,
4892       Value* input,
4893       Value* dim, // Only used for tensor slicing
4894       const SliceExpr& slice) {
4895     Value* start = nullptr;
4896     Value* end = nullptr;
4897     Value* step = nullptr;
4898     if (slice.start().present()) {
4899       start = emitExpr(Expr(slice.start().get()));
4900     }
4901     if (slice.end().present()) {
4902       end = emitExpr(Expr(slice.end().get()));
4903     }
4904     if (slice.step().present()) {
4905       step = emitExpr(Expr(slice.step().get()));
4906     }
4907     return emitSliceOp(loc, input, dim, start, end, step);
4908   }
4909 
emitUnsqueezetorch::jit::to_ir4910   Value* emitUnsqueeze(const SourceRange& loc, Value* input, Value* dim_val) {
4911     return emitBuiltinCall(loc, *graph, aten::unsqueeze, {input, dim_val}, {});
4912   }
4913 
emitIndextorch::jit::to_ir4914   Value* emitIndex(
4915       const SourceRange& loc,
4916       Value* input,
4917       at::ArrayRef<Value*> indices) {
4918     // NB: the index of aten::index should be a type of List[Optional[Tensor]],
4919     // this is to support the case like t[:, :, 1] where : here indicates a
4920     // None/undefined tensor(optional tensor)
4921     auto* index =
4922         graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
4923             ->output();
4924     return emitBuiltinCall(loc, *graph, aten::index, {input, index}, {});
4925   }
4926 
4927   // Emits multidimensional slicing with int and slice indices.
4928   // Returns:
4929   // - Value*: the input after it has been indexed by int and slice indices.
4930   // - vector<Value*>: A list of tensor Value* indices that have not been
4931   // applied yet.
4932   //   Should be NULL at indices where sliceable (post-slicing) isn't indexed by
4933   //   a tensor.
  std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    // Overall, to handle indexing (other than Tensors), we need to handle a
    // couple different things. For example, for x[1:3, None, 4], each of these
    // different index types (slice, None, and integer) result in different
    // number of dimensions. Slicing doesn't change the number of dimensions,
    // None adds a dimension, and integer removes a dimension. As these indexing
    // operations are applied left to right, the actual index that it's being
    // applied to depends on the previous operations. Ellipses indexing throws
    // another wrinkle. Ellipses selects any remaining unspecified dimensions.
    // Thus, for indexes following an ellipses, the actual index an indexing
    // operation is being applied to depends on the operations to the right.
    // Thus, we do two passes, one from left to right up until the ellipses, and
    // one from right to left.

    std::vector<Value*> tensor_indices;

    auto insert_value_for_dim = [&](int64_t dim) {
      return graph->insertConstant(dim, loc);
    };
    // dims[i]: the dimension subscript i applies to (negative when it
    // was resolved in the right-to-left pass after an ellipsis).
    std::vector<int64_t> dims(subscript_exprs.size());
    // exprs[i]: the emitted index Value, or nullopt for slice
    // expressions / slice objects / ellipses, which are handled in the
    // final loop below.
    std::vector<std::optional<Value*>> exprs(
        subscript_exprs.size(), std::nullopt);

    // Records the dimension subscript `expr_idx` applies to and returns
    // the dimension the *next* subscript applies to (moving
    // left-to-right, or right-to-left when `is_reverse` is set).
    auto handle_indexing = [&](const Expr& subscript_expr,
                               size_t expr_idx,
                               int64_t dim,
                               bool is_reverse = false) {
      dims[expr_idx] = dim;

      // Slice expression case, does not represent a single index.
      if (subscript_expr.kind() == TK_SLICE_EXPR) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim + 1;
        }
      }

      // Slice object case, does not represent a single index.
      auto subscript_sv = emitSugaredExpr(subscript_expr, 1);
      if (dynamic_cast<SliceValue*>(subscript_sv.get())) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim + 1;
        }
      }

      TypePtr type_hint;
      if (subscript_expr.kind() == TK_NONE) {
        type_hint = NoneType::get();
      }
      auto index = emitExpr(subscript_expr, type_hint);

      // Accept list as subscript but convert it to a Tensor
      // since it's equivalent to indexing with Tensor.
      // The list can be a list literal or list variable.
      // Advanced indexing using list:
      // @torch.jit.script
      // def f(x):
      //   return x[[0, 1, 5]]  # or
      //   return x[[0, 1], [0, 1]]  # or
      //   return x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]  # or
      //   ls = [0, 1]
      //   return x[ls]
      // Statements above are equivalent to advanced indexing using Tensor:
      // @torch.jit.script
      // def f(x):
      //   return x[torch.tensor([0, 1, 5])]  # or
      //   return x[torch.tensor([0, 1]), torch.tensor([0, 1])]  # or
      //   return x[torch.tensor([[0, 1], [0, 1]]),
      //            torch.tensor([[0, 1], [0, 1]])]  # or
      //   ls = [0, 1]
      //   return x[torch.tensor(ls)]
      if (index->type()->kind() == c10::TypeKind::ListType) {
        // Always create index tensor as LongTensor.
        // This is to match Pytorch eager frontend behavior which accepts
        // indexing with float list.
        index = graph->insert(
            aten::tensor, {index}, {NamedValue("dtype", c10::kLong)});
      }

      exprs[expr_idx] = index;
      // None inserts a dimension, an int consumes one, and a tensor
      // index leaves the dimension for at::index to handle later.
      if (index->type()->isSubtypeOf(*NoneType::get())) {
        if (is_reverse) {
          return dim;
        } else {
          return dim + 1;
        }
      } else if (index->type() == IntType::get()) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim;
        }
      } else if (index->type()->isSubtypeOf(*OptionalType::ofTensor())) {
        if (is_reverse) {
          throw(
              ErrorReport(loc)
              << "Ellipses followed by tensor indexing is currently not supported");
        } else {
          return dim + 1;
        }
      } else {
        throw(
            ErrorReport(loc)
            << "Unsupported operation: indexing tensor with unsupported index type '"
            << index->type()->repr_str()
            << "'. Only ints, slices, lists and tensors are supported");
      }
    };

    // Forward pass: everything up to (but not including) an ellipsis.
    size_t idx = 0;
    int64_t dim = 0;
    for (; idx < subscript_exprs.size(); idx++) {
      auto subscript_expr = subscript_exprs[idx];
      if (subscript_expr.kind() == TK_DOTS) {
        break;
      }
      dim = handle_indexing(subscript_expr, idx, dim, /*is_reverse=*/false);
    }
    // Backward pass: subscripts after the ellipsis, with dimensions
    // counted from the right (starting at -1).
    int64_t rdim = -1;
    for (size_t rev_idx = subscript_exprs.size() - 1; rev_idx > idx;
         rev_idx--) {
      auto subscript_expr = subscript_exprs[rev_idx];
      if (subscript_expr.kind() == TK_DOTS) {
        throw(
            ErrorReport(loc)
            << "An index can only have a single ellipsis ('...')");
      }
      rdim =
          handle_indexing(subscript_expr, rev_idx, rdim, /*is_reverse=*/true);
    }
    // Apply the recorded indexing operations now that every subscript's
    // target dimension is known.
    for (const auto i : c10::irange(exprs.size())) {
      if (!exprs[i].has_value()) {
        if (subscript_exprs[i].kind() == TK_SLICE_EXPR) {
          sliceable = emitSlice(
              loc,
              sliceable,
              insert_value_for_dim(dims[i]),
              SliceExpr(subscript_exprs[i]));
          continue;
        }

        if (subscript_exprs[i].kind() == TK_DOTS) {
          continue;
        }

        auto subscript_sv = emitSugaredExpr(subscript_exprs[i], 1);
        if (const auto slice_value =
                dynamic_cast<SliceValue*>(subscript_sv.get())) {
          sliceable = emitSliceOp(
              loc,
              sliceable,
              insert_value_for_dim(dims[i]),
              slice_value->start(),
              slice_value->stop(),
              slice_value->step());
        }

        continue;
      }
      auto expr = exprs[i].value();
      if (expr->type()->isSubtypeOf(*NoneType::get())) {
        sliceable =
            emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
      } else if (expr->type() == IntType::get()) {
        sliceable =
            emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
      } else if (expr->type()->isSubtypeOf(*OptionalType::ofTensor())) {
        tensor_indices.resize(dims[i] + 1);
        tensor_indices[dims[i]] = expr;
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Trying to process index type that we don't support.");
      }
    }
    // at::index takes in a List[Optional[Tensor]] where some dims can be None.
    // create None node with optional tensor output type and pass to at::index.
    for (auto& index : tensor_indices) {
      if (index == nullptr) {
        index = graph->insertNode(graph->createNone())->output();
      }
    }
    return std::make_pair(sliceable, tensor_indices);
  }
5123 
5124   // Desugars multidim slicing into slice/select/index/unsqueeze calls.
5125   //
5126   // XXX: Errors in user code are not elegantly reported.
5127   // Let's say someone were to do the following:
5128   //   @torch.jit.script
5129   //   def fn(x):
5130   //       return x[0, 1]
5131   //   fn(torch.randn(5))
5132   // Because we desugar this into two aten::select ops, the error message
5133   // complains about aten::select failing rather than there "not being
5134   // enough dimensions to index".
5135   //
5136   // The strategy is to slice and select the tensor for int and slices first
5137   // in one pass and then apply at::index on the result of the
5138   // slicing/selecting. Call the tensor after we've applied slice / select the
5139   // `sliced`. tensor_indices should have the same size as sliced.dim():
5140   // - tensor_indices[i] = NULL if we should not index `sliced` at dim i
5141   // - tensor_indices[i] = t if we should index `sliced` at dim i with tensor t.
emitMultidimSlicingtorch::jit::to_ir5142   Value* emitMultidimSlicing(
5143       const SourceRange& loc,
5144       Value* sliceable,
5145       const List<Expr>& subscript_exprs) {
5146     if (!sliceable->type()->isSubtypeOf(*TensorType::get())) {
5147       throw(
5148           ErrorReport(loc)
5149           << "Unsupported operation: attempted to use multidimensional "
5150           << "indexing on a non-tensor type");
5151     }
5152 
5153     std::vector<Value*> tensor_indices;
5154     std::tie(sliceable, tensor_indices) =
5155         emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);
5156 
5157     if (tensor_indices.empty()) {
5158       // XXX: Might need to at::alias this when we support mutability
5159       return sliceable;
5160     }
5161 
5162     return emitIndex(loc, sliceable, tensor_indices);
5163   }
5164 
5165   // Desugars slice syntactic sugar tensor[begin:end] -> tensor.slice(begin,
5166   // end).
emitBasicSlicetorch::jit::to_ir5167   Value* emitBasicSlice(
5168       const SourceRange& loc,
5169       Value* sliceable,
5170       const List<Expr>& subscript_exprs) {
5171     AT_ASSERT(subscript_exprs.size() == 1);
5172     AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
5173     auto slice_exp = SliceExpr(subscript_exprs[0]);
5174     Value* maybe_dim = nullptr;
5175     if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
5176       // If the sliceable object is a tensor, specify a default dimension
5177       maybe_dim = graph->insertConstant(0, loc);
5178     }
5179     return emitSlice(loc, sliceable, maybe_dim, slice_exp);
5180   }
5181 
getAdjTupleIndextorch::jit::to_ir5182   int64_t getAdjTupleIndex(
5183       const SourceRange& loc,
5184       const TupleTypePtr& tuple_type,
5185       int64_t input_index,
5186       bool allow_out_of_bounds) {
5187     // set index to be positive to simplify logic in runtime
5188     int64_t adj_index = input_index;
5189     int64_t tuple_len = static_cast<int64_t>(tuple_type->elements().size());
5190     if (input_index < 0) {
5191       adj_index = tuple_len + input_index;
5192     }
5193     if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
5194       throw(
5195           ErrorReport(loc) << "Tuple index out of range. Tuple is length "
5196                            << tuple_len << " and index is " << input_index);
5197     }
5198     return adj_index;
5199   }
5200 
  // When a list is marked const in a module, it gets converted to a tuple.
  // The result is indexing into a Tuple which contains only one type
  // is quite common. since indexing will likely be done in a for loop,
  // we do not want to invoke the overhead of converting the tuple to a list
  // each iter.
  //
  // Emits tuple[idx] as a prim::TupleIndex node. The output type is the
  // element type at the (constant) index, or the shared element type
  // when the index is not a compile-time constant.
  Value* emitTupleIndex(
      const SourceRange& loc,
      Value* tuple_val,
      Value* idx_val) {
    auto tuple_typ = tuple_val->type()->cast<TupleType>();
    auto elems = tuple_typ->elements();
    TypePtr output_type;
    if (idx_val->type() != IntType::get()) {
      throw(ErrorReport(loc) << "tuple index must be an integer");
    }
    auto idx = toIValue(idx_val);
    if (!idx) {
      // Non-constant index: only resolvable when the tuple is
      // convertible to a list of its first element's type, i.e. every
      // element shares that type.
      if (elems.empty() ||
          !convertibleToList(tuple_typ, ListType::create(elems[0]))) {
        throw(
            ErrorReport(loc)
            << "Cannot index into a " << tuple_typ->repr_str()
            << " with a non-integer literal because we cannot resolve the output type");
      }
      output_type = elems[0];
    } else {
      // Constant index: normalize it (bounds-checked) and read off the
      // element type.
      auto adj_index = getAdjTupleIndex(
          loc, tuple_typ, idx->toInt(), /*allow_out_of_bounds*/ false);
      output_type = elems[adj_index];
    }
    return graph
        ->insertNode(graph->createTupleIndex(tuple_val, idx_val, output_type))
        ->output();
  }
5235 
getSliceIndtorch::jit::to_ir5236   int64_t getSliceInd(Value* idx_val, const SourceRange& loc) {
5237     auto ivalue = toIValue(idx_val);
5238     if (ivalue && ivalue->isInt()) {
5239       return ivalue->to<int64_t>();
5240     } else {
5241       throw(
5242           ErrorReport(loc) << "tuple slice indices must be integer constants");
5243     }
5244   }
5245 
emitTupleSlicetorch::jit::to_ir5246   Value* emitTupleSlice(
5247       const SourceRange& loc,
5248       const NamedValue& tuple_val,
5249       const std::vector<std::optional<NamedValue>>& tuple_args) {
5250     auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
5251     auto tuple_len = tuple_type->elements().size();
5252     auto beg_val = tuple_args[0];
5253     auto end_val = tuple_args[1];
5254     auto step = tuple_args[2];
5255 
5256     int64_t step_size = 1;
5257     if (step) {
5258       auto val = toIValue(step->value(*graph));
5259       TORCH_CHECK(val->isInt(), "Step size should always be an integer");
5260       step_size = val->to<int64_t>();
5261     }
5262 
5263     int64_t beg = std::numeric_limits<int64_t>::max();
5264     if (beg_val) {
5265       beg = getAdjTupleIndex(
5266           loc, tuple_type, getSliceInd(beg_val->value(*graph), loc), true);
5267     }
5268 
5269     int64_t end = std::numeric_limits<int64_t>::max();
5270     if (end_val) {
5271       end = getAdjTupleIndex(
5272           loc, tuple_type, getSliceInd(end_val->value(*graph), loc), true);
5273     }
5274 
5275     int64_t num_values = slice_indices_adjust(
5276         static_cast<int64_t>(tuple_len), &beg, &end, step_size);
5277 
5278     return graph
5279         ->insertNode(graph->createTupleSlice(
5280             tuple_val.value(*graph), beg, step_size, num_values))
5281         ->output();
5282   }
5283 
  // Emits IR for a subscript expression `value[subscript_exprs]`.
  // Dispatches on the shape of the subscript:
  //   * more than one subscript expression -> multi-dimensional slicing
  //   * a single TK_SLICE_EXPR             -> tuple/module/basic slicing
  //   * a single index expression          -> tuple index, tensor indexing,
  //                                           or the value's own getitem
  // `type_hint` is forwarded to SugaredValue::getitem for container types.
  std::shared_ptr<SugaredValue> emitSubscript(
      const Subscript& subscript,
      TypePtr type_hint = nullptr) {
    const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1);
    const List<Expr>& subscript_exprs = subscript.subscript_exprs();
    const SourceRange& range = subscript.range();
    const SourceRange& val_range = subscript.value().range();
    // e.g. x[i, j]: treat as multidim slicing on the value.
    if (subscript_exprs.size() != 1) {
      return std::make_shared<SimpleValue>(emitMultidimSlicing(
          range, sv->asValue(val_range, method), subscript_exprs));
    }
    if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
      // TODO @wconstab refactor using Symbol instead of string compare
      if (sv->kind() == "module") {
        // Slicing isn't currently implemented for Sequential/ModuleList,
        // but is implemented for Tuples, so a quick workaround is to
        // convert to a tuple of Modules for slicing support.
        auto s_tuple_val =
            sv->asTupleValue(val_range, method)->asValue(val_range, method);
        const SliceExpr& slice = SliceExpr(subscript_exprs[0]);
        // Build the (begin, end, step) argument triple; absent components
        // become nullopt so emitTupleSlice applies Python slice defaults.
        std::vector<std::optional<NamedValue>> tuple_args;
        tuple_args.reserve(3);
        if (slice.start().present()) {
          auto begin = NamedValue(
              val_range, "begin", emitExpr(Expr(slice.start().get())));
          tuple_args.emplace_back(begin);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }

        if (slice.end().present()) {
          auto end =
              NamedValue(val_range, "end", emitExpr(Expr(slice.end().get())));
          tuple_args.emplace_back(end);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }

        if (slice.step().present()) {
          auto step =
              NamedValue(val_range, "step", emitExpr(Expr(slice.step().get())));
          tuple_args.emplace_back(step);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }
        auto tupleSliceValue =
            emitTupleSlice(val_range, s_tuple_val, tuple_args);
        return std::make_shared<SimpleValue>(tupleSliceValue);
      } else {
        // Non-module value with a slice expression: ordinary slicing.
        return std::make_shared<SimpleValue>(emitBasicSlice(
            range, sv->asValue(val_range, method), subscript_exprs));
      }
    } else {
      AT_ASSERT(subscript_exprs.size() == 1);
      Value* sliceable = sv->asValue(val_range, method);

      // In case of subscript expression being a Python Slice object.
      auto subscript_sv = emitSugaredExpr(subscript_exprs[0], 1);
      if (const auto slice_value =
              dynamic_cast<SliceValue*>(subscript_sv.get())) {
        Value* dim = nullptr;
        // aten::slice.tensor needs an additional `dim` input.
        if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
          dim = method.graph()->insertConstant(0, val_range);
        }

        Value* sliced = emitSliceOp(
            val_range,
            sliceable,
            dim,
            slice_value->start(),
            slice_value->stop(),
            slice_value->step());
        return std::make_shared<SimpleValue>(sliced);
      }

      // subscript is not a slice object, then it must be convertible to
      // a normal value.
      // Desugars gather syntactic sugar foo[i]
      Value* idx = subscript_sv->asValue(val_range, method);
      if (sliceable->type()->cast<TupleType>()) {
        // Tuple indexing resolves the element type at compile time.
        return std::make_shared<SimpleValue>(
            emitTupleIndex(range, sv->asValue(val_range, method), idx));
      } else if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
        return std::make_shared<SimpleValue>(
            emitMultidimSlicing(range, sliceable, subscript_exprs));
      } else {
        // Everything else (lists, dicts, user classes, ...) goes through the
        // sugared value's own getitem handling.
        return sv->getitem(range, method, idx, std::move(type_hint));
      }
    }
  }
5375 };
5376 
5377 struct FunctionResolver : public Resolver {
FunctionResolvertorch::jit::FunctionResolver5378   explicit FunctionResolver(
5379       Resolver* otherResolver,
5380       const std::unordered_map<std::string, Function*>& functionTable)
5381       : otherResolver_(otherResolver), functionTable_(functionTable) {}
5382 
resolveValuetorch::jit::FunctionResolver5383   std::shared_ptr<SugaredValue> resolveValue(
5384       const std::string& name,
5385       GraphFunction& m,
5386       const SourceRange& loc) override {
5387     auto it = functionTable_.find(name);
5388     if (it != functionTable_.end()) {
5389       return std::make_shared<FunctionValue>(it->second);
5390     }
5391     return otherResolver_->resolveValue(name, m, loc);
5392   }
5393 
resolveTypetorch::jit::FunctionResolver5394   TypePtr resolveType(const std::string& name, const SourceRange& loc)
5395       override {
5396     return otherResolver_->resolveType(name, loc);
5397   }
5398 
5399  private:
5400   Resolver* otherResolver_;
5401   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
5402   const std::unordered_map<std::string, Function*>& functionTable_;
5403 };
5404 
CompilationUnit::CompilationUnit(const std::string& source)
    : CompilationUnit() {
  // Calls define() with the native resolver to generate graphs for every
  // function parsed from `source`. No `self` is supplied, so the functions
  // are compiled as free functions in this unit's global namespace.
  define(std::nullopt, source, nativeResolver(), nullptr);
}
5410 
5411 // This pair represents a pair of functions (getter and setter) obtained from
5412 // compiling a Property.
struct CompilationUnit::PropertyPair
    : public std::pair<std::unique_ptr<Function>, std::unique_ptr<Function>> {
  // Builds a (getter, setter) pair of compiled functions for a Property.
  // The getter is mandatory; the setter may be null (read-only property).
  PropertyPair(
      std::unique_ptr<Function> getter,
      std::unique_ptr<Function> setter) {
    TORCH_INTERNAL_ASSERT(getter, "Property pair must have defined getter")
    this->first = std::move(getter);
    this->second = std::move(setter);
  }

  // The compiled getter; never null after construction.
  std::unique_ptr<Function>& getGetter() {
    return this->first;
  }

  // The compiled setter; null when the property has no setter.
  std::unique_ptr<Function>& getSetter() {
    return this->second;
  }
};
5431 
define_property(const std::optional<c10::QualifiedName> & prefix,const Property & prop,const ResolverPtr & resolver,const Self * self,const std::unordered_map<std::string,Function * > & function_table,bool shouldMangle) const5432 CompilationUnit::PropertyPair CompilationUnit::define_property(
5433     const std::optional<c10::QualifiedName>& prefix,
5434     const Property& prop,
5435     const ResolverPtr& resolver,
5436     const Self* self,
5437     const std::unordered_map<std::string, Function*>& function_table,
5438     bool shouldMangle) const {
5439   // self must be defined because properties are features of classes and
5440   // modules.
5441   TORCH_INTERNAL_ASSERT(self);
5442 
5443   // Compile the getter function.
5444   std::unique_ptr<Function> getter_fn = define(
5445       prefix, prop.getter(), resolver, self, function_table, shouldMangle);
5446 
5447   // Compile the setter function if it exists.
5448   std::unique_ptr<Function> setter_fn = nullptr;
5449   if (prop.setter().present()) {
5450     setter_fn = define(
5451         prefix,
5452         prop.setter().get(),
5453         resolver,
5454         self,
5455         function_table,
5456         shouldMangle);
5457   }
5458 
5459   // Add the property to the class type definition.
5460   self->getClassType()->addProperty(
5461       prop.name().name(), getter_fn.get(), setter_fn.get());
5462 
5463   return PropertyPair(std::move(getter_fn), std::move(setter_fn));
5464 }
5465 
define(const std::optional<QualifiedName> & prefix,const Def & def,const ResolverPtr & resolver,const Self * self,const std::unordered_map<std::string,Function * > & function_table,bool shouldMangle,CompilationUnit::FunctionType type,std::optional<size_t> operator_set_version) const5466 std::unique_ptr<Function> CompilationUnit::define(
5467     const std::optional<QualifiedName>& prefix,
5468     const Def& def,
5469     const ResolverPtr& resolver,
5470     const Self* self,
5471     const std::unordered_map<std::string, Function*>& function_table,
5472     bool shouldMangle,
5473     CompilationUnit::FunctionType type,
5474     std::optional<size_t> operator_set_version) const {
5475   TORCH_INTERNAL_ASSERT(resolver);
5476   auto _resolver = resolver;
5477   if (!self) {
5478     // if self is defined, then these are methods and do not go into the
5479     // global namespace otherwise, they get defined together so we add them to
5480     // the function table so the methods can see each other
5481     _resolver =
5482         std::make_shared<FunctionResolver>(resolver.get(), function_table);
5483   }
5484   auto creator = [def, _resolver, self](GraphFunction& method) {
5485     // Store the function name so that it can be referenced if there is an error
5486     // while compiling this function
5487     std::string call_name = method.qualname().name();
5488     if (self) {
5489       auto atoms = method.qualname().atoms();
5490       // There should be at least a ClassName.method_name
5491       TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
5492       call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1);
5493     }
5494     ErrorReport::CallStack call(call_name, def.range());
5495     to_ir(def, _resolver, self, method);
5496   };
5497   auto name = prefix ? QualifiedName(*prefix, def.name().name())
5498                      : QualifiedName(def.name().name());
5499   if (shouldMangle) {
5500     // If `shouldMangle` is set, we should generate a unique name for this
5501     // function if there is already an existing one.
5502     if (find_function(name)) {
5503       name = mangle(name);
5504     }
5505   }
5506 
5507   auto graph = std::make_shared<Graph>();
5508   graph->set_op_version(operator_set_version);
5509 
5510   auto fn = std::make_unique<GraphFunction>(std::move(name), graph, creator);
5511   if (self) {
5512     // Register this as a method on `self`'s type
5513     if (type == CompilationUnit::FunctionType::Hook) {
5514       self->getClassType()->addForwardHook(fn.get());
5515     } else if (type == CompilationUnit::FunctionType::PreHook) {
5516       self->getClassType()->addForwardPreHook(fn.get());
5517     } else {
5518       self->getClassType()->addMethod(fn.get());
5519     }
5520   }
5521   return fn;
5522 }
5523 
define(const std::optional<c10::QualifiedName> & prefix,const std::vector<Property> & properties,const std::vector<ResolverPtr> & propResolvers,const std::vector<Def> & definitions,const std::vector<ResolverPtr> & defResolvers,const Self * self,bool shouldMangle,std::optional<size_t> operator_set_version)5524 std::vector<Function*> CompilationUnit::define(
5525     const std::optional<c10::QualifiedName>& prefix,
5526     const std::vector<Property>& properties,
5527     const std::vector<ResolverPtr>& propResolvers,
5528     const std::vector<Def>& definitions,
5529     const std::vector<ResolverPtr>& defResolvers,
5530     const Self* self,
5531     bool shouldMangle,
5532     std::optional<size_t> operator_set_version) {
5533   TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size());
5534   TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size());
5535   std::vector<Function*> functions;
5536   std::unordered_map<std::string, Function*> function_table;
5537 
5538   // Records fn in function_table, functions and with register_function.
5539   // This is done several times below, so this lambda helps avoid repeating
5540   // code.
5541   auto record_function = [&](std::unique_ptr<Function> fn) {
5542     function_table[fn->name()] = fn.get();
5543     functions.emplace_back(fn.get());
5544     this->register_function(std::move(fn));
5545   };
5546 
5547   for (const auto i : c10::irange(properties.size())) {
5548     PropertyPair property_fns = define_property(
5549         prefix,
5550         properties[i],
5551         propResolvers[i],
5552         self,
5553         function_table,
5554         shouldMangle);
5555 
5556     auto& getter_fn = property_fns.getGetter();
5557     auto& setter_fn = property_fns.getSetter();
5558 
5559     record_function(std::move(getter_fn));
5560 
5561     if (setter_fn) {
5562       record_function(std::move(setter_fn));
5563     }
5564   }
5565 
5566   for (const auto i : c10::irange(definitions.size())) {
5567     auto fn = define(
5568         prefix,
5569         definitions[i],
5570         defResolvers[i],
5571         self,
5572         function_table,
5573         shouldMangle,
5574         CompilationUnit::FunctionType::Method,
5575         operator_set_version);
5576 
5577     record_function(std::move(fn));
5578   }
5579 
5580   // We need to compile `__init__` first, since it can determine what attributes
5581   // are available to other methods. So reorder the definitions accordingly.
5582   for (auto& kv : function_table) {
5583     if (kv.first == "__init__") {
5584       kv.second->ensure_defined();
5585     }
5586   }
5587 
5588   for (Function* function : functions) {
5589     function->ensure_defined();
5590   }
5591 
5592   return functions;
5593 }
5594 
define_hooks(const std::optional<c10::QualifiedName> & prefix,const std::vector<Def> & hookDefs,const std::vector<ResolverPtr> & hookResolvers,const std::vector<Def> & preHookDefs,const std::vector<ResolverPtr> & preHookResolvers,const Self * self,bool shouldMangle)5595 void CompilationUnit::define_hooks(
5596     const std::optional<c10::QualifiedName>& prefix,
5597     const std::vector<Def>& hookDefs,
5598     const std::vector<ResolverPtr>& hookResolvers,
5599     const std::vector<Def>& preHookDefs,
5600     const std::vector<ResolverPtr>& preHookResolvers,
5601     const Self* self,
5602     bool shouldMangle) {
5603   TORCH_INTERNAL_ASSERT(hookDefs.size() == hookResolvers.size());
5604   TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookResolvers.size());
5605   std::vector<Function*> functions;
5606   std::unordered_map<std::string, Function*> function_table;
5607 
5608   // check hook for name collisions and redefinition
5609   auto check_collisions = [&](const Def& hook) -> Function* {
5610     auto name = prefix ? QualifiedName(*prefix, hook.name().name()).name()
5611                        : QualifiedName(hook.name().name()).name();
5612     // check if hook is already defined for this module
5613     auto found_hook = function_table.find(name);
5614     auto existing_hook =
5615         found_hook != function_table.end() ? found_hook->second : nullptr;
5616     // check if hook name is already defined on module as method
5617     if (existing_hook == nullptr) {
5618       TORCH_CHECK(
5619           self->getClassType()->findMethod(name) == nullptr &&
5620               self->getClassType()->findHook(name) == nullptr,
5621           "Can't define hook: ",
5622           name,
5623           " on class: ",
5624           self->getClassType()->repr_str(),
5625           " because a method or hook with that name already exists.");
5626     }
5627     return existing_hook;
5628   };
5629 
5630   // build_schema for checking
5631   auto build_schema = [&](const Def& hook_def,
5632                           const ResolverPtr& hook_res) -> FunctionSchema {
5633     ScriptTypeParser typeParser(hook_res);
5634     FunctionSchema schema =
5635         typeParser.parseSchemaFromDef(hook_def, true /* skip_self*/);
5636     // need to add self as the first because we skipped it
5637     std::vector<Argument> arguments;
5638     arguments.emplace_back(
5639         hook_def.decl().params()[0].ident().name(), self->getClassType());
5640     arguments.insert(
5641         arguments.end(), schema.arguments().begin(), schema.arguments().end());
5642     return schema.cloneWithArguments(arguments);
5643   };
5644 
5645   // define hooks
5646   for (const auto i : c10::irange(hookDefs.size())) {
5647     // check to see if already defined this hook
5648     auto existing_fn = check_collisions(hookDefs[i]);
5649     if (existing_fn != nullptr) {
5650       // add it to class type again so it's called
5651       self->getClassType()->addForwardHook(existing_fn);
5652       continue;
5653     }
5654     // define hook
5655     auto fn = define(
5656         prefix,
5657         hookDefs[i],
5658         hookResolvers[i],
5659         self,
5660         function_table,
5661         shouldMangle,
5662         CompilationUnit::FunctionType::Hook);
5663 
5664     function_table[fn->name()] = fn.get();
5665     functions.emplace_back(fn.get());
5666     this->register_function(std::move(fn));
5667     self->getClassType()->checkForwardHookSchema(
5668         i, build_schema(hookDefs[i], hookResolvers[i]));
5669     functions.back()->ensure_defined();
5670   }
5671 
5672   // define pre_hooks
5673   for (const auto i : c10::irange(preHookDefs.size())) {
5674     // check to see if already defined this hook
5675     auto existing_fn = check_collisions(preHookDefs[i]);
5676     if (existing_fn != nullptr) {
5677       // add it to class type again so it's called
5678       self->getClassType()->addForwardPreHook(existing_fn);
5679       continue;
5680     }
5681     // define pre_hook
5682     auto fn = define(
5683         prefix,
5684         preHookDefs[i],
5685         preHookResolvers[i],
5686         self,
5687         function_table,
5688         shouldMangle,
5689         CompilationUnit::FunctionType::PreHook);
5690 
5691     function_table[fn->name()] = fn.get();
5692     functions.emplace_back(fn.get());
5693     this->register_function(std::move(fn));
5694     self->getClassType()->checkForwardPreHookSchema(
5695         i, build_schema(preHookDefs[i], preHookResolvers[i]));
5696     functions.back()->ensure_defined();
5697   }
5698 }
5699 
define(const std::optional<QualifiedName> & prefix,const std::string & source,const ResolverPtr & resolver,const Self * self)5700 std::vector<Function*> CompilationUnit::define(
5701     const std::optional<QualifiedName>& prefix,
5702     const std::string& source,
5703     const ResolverPtr& resolver,
5704     const Self* self) {
5705   Parser p(std::make_shared<Source>(source, "<string>", 1));
5706   std::vector<Def> definitions;
5707   std::vector<ResolverPtr> resolvers;
5708   while (p.lexer().cur().kind != TK_EOF) {
5709     auto def = Def(p.parseFunction(/*is_method=*/bool(self)));
5710     definitions.push_back(def);
5711     resolvers.push_back(resolver);
5712   }
5713   return define(
5714       prefix,
5715       /*properties=*/{},
5716       /*propResolvers=*/{},
5717       definitions,
5718       resolvers,
5719       self);
5720 }
5721 
eraseListLiterals(std::shared_ptr<Graph> & graph)5722 static void eraseListLiterals(std::shared_ptr<Graph>& graph) {
5723   DepthFirstGraphNodeIterator it(graph);
5724 
5725   for (auto next_node = it.next(); next_node != nullptr;) {
5726     Node* node = next_node;
5727     next_node = it.next();
5728 
5729     if (node->kind() == prim::EmptyListLiteral) {
5730       if (node->hasUses()) {
5731         TORCH_INTERNAL_ASSERT(
5732             node->output()->type()->isSubtypeOf(ListType::ofTensors()));
5733 
5734         auto li = graph->createList(TensorType::get(), {});
5735         li->insertBefore(node);
5736         node->replaceAllUsesWith(li);
5737       }
5738       node->destroy();
5739     }
5740   }
5741 }
5742 
// Runs the post-compilation cleanup pipeline on a freshly emitted graph.
// NOTE: pass ordering here is deliberate (see the per-pass comments below);
// do not reorder without understanding the dependencies.
void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
  // Closures must be lifted and fork closures inlined before any other
  // transformation sees the graph.
  liftClosures(to_clean);
  inlineForkedClosures(to_clean);

  if (getInlineEverythingMode()) {
    Inline(*to_clean);
  }

  // these exist temporarily in initial compilation
  eraseListLiterals(to_clean);

  // remove any uses of tuples that we inserted that are not needed
  LowerSimpleTuples(to_clean);

  // full constant propagation runs ops with mutable inputs if it can
  // prove that the inputs are not mutated anywhere in the graph.
  // if a mutating node is removed in the graph (e.g. constant prop inlined a
  // a constant if) then the next time constant prop is run it might be able
  // to run nodes it was not able to previously, and the graph may change
  // (jitter) So we run only constant prop w immutable types here bc
  // successive runs of immutable constant prop does not change the graph
  ConstantPropagationImmutableTypes(to_clean);

  // Constant Pooling pass must be after ConstantPropagation, which can create
  // new constants that needs to be pooled.
  ConstantPooling(to_clean);

  // For jitter
  CanonicalizeOutputs(to_clean);

  // Annotate aten::warns so that each has its unique ID. This enables us to
  // mimic Python behavior of only emitting each warning only once.
  AnnotateWarns(to_clean);
}
5777 
5778 // we consider _N where N is a number, to be a non-meaningful name
5779 // and do not record it as a unique name. This allows python printing to
5780 // be able to export and import more consistently named graphs
// Returns true when `name` is a user-meaningful identifier worth preserving
// as a unique name during python printing. Non-meaningful names are: the
// empty string, '$'-prefixed compiler temporaries, and "_<digits>" style
// auto-generated names (e.g. "_0", "_12"); a bare "_" is also rejected.
bool meaningfulName(const std::string& name) {
  if (name.empty()) {
    return false;
  }
  if (name[0] == '$') {
    return false;
  }
  if (name[0] != '_') {
    return true;
  }
  // Leading underscore: meaningful only if some later character is a
  // non-digit. BUGFIX: cast to unsigned char before calling isdigit —
  // passing a (possibly negative) plain char is undefined behavior.
  return !std::all_of(name.begin() + 1, name.end(), [](unsigned char c) {
    return std::isdigit(c) != 0;
  });
}
5794 
// Creates and registers an InterfaceType from a class definition. Interface
// bodies may only contain method declarations whose bodies consist of
// docstrings (string-literal expression statements) followed by `pass`;
// each declaration must have an annotated return type.
void CompilationUnit::define_interface(
    const c10::QualifiedName& qualifiedName,
    const ClassDef& classDef,
    ResolverPtr rcb,
    bool is_module) {
  ScriptTypeParser typeParser(std::move(rcb));
  InterfaceTypePtr iface =
      InterfaceType::create(c10::QualifiedName(qualifiedName), is_module);
  for (const Stmt& stmt : classDef.body()) {
    if (stmt.kind() != TK_DEF) {
      throw(
          ErrorReport(stmt)
          << "interface declarations can only contain method definitions");
    }
    auto method_def = Def(stmt);
    // Interfaces have no bodies to infer a return type from, so an explicit
    // annotation is mandatory.
    if (!method_def.decl().return_type().present()) {
      throw(
          ErrorReport(method_def)
          << "interface declarations must have a return type annotated.");
    }
    FunctionSchema schema =
        typeParser.parseSchemaFromDef(method_def, /* skip_self*/ true);
    // need to add self as the first because we skipped it
    std::vector<Argument> arguments;
    arguments.emplace_back(method_def.decl().params()[0].ident().name(), iface);
    arguments.insert(
        arguments.end(), schema.arguments().begin(), schema.arguments().end());
    iface->addMethod(schema.cloneWithArguments(std::move(arguments)));
    // we need to make sure everything but the last element is just string
    // literals (aka comments) unless there is "pass" in between
    // NOTE(review): assumes every method def has at least one statement
    // (stmts_size >= 1) — the parser appears to guarantee a non-empty body;
    // otherwise `stmts_size - 1` below would wrap around. TODO confirm.
    auto stmts_size = method_def.statements().size();
    for (size_t i = 0; i < stmts_size - 1; i++) {
      auto cur_statement = method_def.statements()[i];
      if (cur_statement.kind() == TK_EXPR_STMT) {
        auto expr = ExprStmt(cur_statement).expr();
        if (expr.kind() != TK_STRINGLITERAL) {
          throw(
              ErrorReport(method_def.range())
              << "interfaces declarations should only contain a single 'pass' statement.");
        }
      }
      // if we see a "pass", we just stop there
      // (remaining statements are intentionally not validated)
      if (cur_statement.kind() == TK_PASS) {
        this->register_type(iface);
        return;
      }
    }

    if (method_def.statements()[stmts_size - 1].kind() != TK_PASS) {
      throw(
          ErrorReport(method_def.range())
          << "interfaces declarations should contain 'pass' statement.");
    }
  }
  this->register_type(iface);
}
5851 
5852 } // namespace torch::jit
5853