#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tree_views.h>

#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/annotate_warns.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inline_forked_closures.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lift_closures.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <torch/csrc/jit/testing/hooks_for_testing.h>

#include <torch/csrc/jit/ir/constants.h>

#include <c10/util/hash.h>
#include <optional>

#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <set>
#include <stack>
44
namespace {
// Decides whether debug source locations should be retained for a file of
// `file_size` bytes. Small files (< 512 KiB) always keep locations; for
// larger files, retention is opt-in via the
// PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION environment variable. An unset
// variable, or a value of "0", "FALSE", or "false", disables retention;
// any other value enables it.
// (Fix: this function uses std::getenv/std::strcmp, which previously relied
// on transitive includes; <cstdlib> and <cstring> are now included directly.)
bool reportSourceLocation(size_t file_size) {
  if (file_size < 512ull * 1024) {
    return true;
  }
  const char* enable_env =
      std::getenv("PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION");
  bool flag = true;
  if (enable_env == nullptr || std::strcmp(enable_env, "0") == 0 ||
      std::strcmp(enable_env, "FALSE") == 0 ||
      std::strcmp(enable_env, "false") == 0) {
    flag = false;
  }
  return flag;
}
} // namespace
61
62 namespace torch::jit {
63
64 using FunctionTable = std::unordered_map<std::string, Function&>;
65 using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
66 using TypeTable = std::unordered_map<std::string, TypePtr>;
67 using AttributeMap = std::unordered_map<std::string, Const>;
68 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
69
70 struct Refinement {
Refinementtorch::jit::Refinement71 Refinement(std::string identifier, TypePtr type)
72 : identifier_(std::move(identifier)), type_(std::move(type)) {}
identifiertorch::jit::Refinement73 const std::string& identifier() const {
74 return identifier_;
75 }
typetorch::jit::Refinement76 TypePtr type() const {
77 return type_;
78 }
79
80 private:
81 std::string identifier_;
82 TypePtr type_;
83 };
84
85 struct RefinementSet {
86 // When a comparison like x is None is made, we associate type refinements
87 // with its true value and its false value. If a boolean that has refinements
88 // associated with it is used in a conditional of an if statement, the true
89 // and false refinements are inserted into the corresponding blocks
90 using Refinements = std::vector<Refinement>;
91
RefinementSettorch::jit::RefinementSet92 RefinementSet(Refinements true_refinements, Refinements false_refinements)
93 : true_refinements_(std::move(true_refinements)),
94 false_refinements_(std::move(false_refinements)) {}
RefinementSettorch::jit::RefinementSet95 RefinementSet(Refinement single) : RefinementSet({std::move(single)}, {}) {}
RefinementSettorch::jit::RefinementSet96 RefinementSet(Refinement single_true, Refinement single_false)
97 : RefinementSet(
98 Refinements({std::move(single_true)}),
99 Refinements({std::move(single_false)})) {}
100 RefinementSet() = default; // empty
Andtorch::jit::RefinementSet101 RefinementSet And(const RefinementSet& rhs) const {
102 // if the result of an AND is true, both a & b had to be true,
103 // so we take the union of a.true_refinements and b.true_refinements.
104 // if the result is false, either a or b could have been false,
105 // so we take their intersection.
106 return RefinementSet(
107 unionSet(true_refinements_, rhs.true_refinements_),
108 intersectSet(false_refinements_, rhs.false_refinements_));
109 }
Ortorch::jit::RefinementSet110 RefinementSet Or(const RefinementSet& rhs) const {
111 // if the result of an OR is true, either a & b could have been true,
112 // so we take the intersection of a.true_refinements & b.true_refinements.
113 // if the result is false, both a and b had to be false,
114 // so we take their union.
115 return RefinementSet(
116 intersectSet(true_refinements_, rhs.true_refinements_),
117 unionSet(false_refinements_, rhs.false_refinements_));
118 }
119
Nottorch::jit::RefinementSet120 RefinementSet Not() const {
121 return RefinementSet(false_refinements_, true_refinements_);
122 }
activeRefinementstorch::jit::RefinementSet123 const std::vector<Refinement> activeRefinements() const {
124 return true_refinements_;
125 }
126
127 private:
sameVartorch::jit::RefinementSet128 static bool sameVar(const Refinement& a, const Refinement& b) {
129 return a.identifier() == b.identifier();
130 }
unionSettorch::jit::RefinementSet131 static Refinements unionSet(const Refinements& a, const Refinements& b) {
132 Refinements result = a;
133 for (const Refinement& r : b) {
134 auto it =
135 std::find_if(result.begin(), result.end(), [&](const Refinement& e) {
136 return e.identifier() == r.identifier();
137 });
138 if (it == result.end()) {
139 result.push_back(r);
140 } else if (*it->type() != *r.type()) {
141 // we only keep refinements when they exactly match one
142 // refinement type, for instance, we do not attempt to refine:
143 // isinstance(x, float) and isinstance(x, int)
144 result.erase(it);
145 }
146 }
147 return result;
148 }
intersectSettorch::jit::RefinementSet149 static Refinements intersectSet(const Refinements& a, const Refinements& b) {
150 Refinements result;
151 for (const Refinement& r : a) {
152 auto it = std::find_if(b.begin(), b.end(), [&](const Refinement& e) {
153 return e.identifier() == r.identifier();
154 });
155 if (it != b.end() && r.type() == it->type()) {
156 result.push_back(r);
157 }
158 }
159 return result;
160 }
161
162 Refinements true_refinements_;
163 Refinements false_refinements_;
164 };
165
166 struct CondValue {
CondValuetorch::jit::CondValue167 CondValue(
168 Value* value,
169 RefinementSet refinements,
170 std::optional<bool> static_if)
171 : value_(value),
172 refinements_(std::move(refinements)),
173 static_if_(static_if) {}
CondValuetorch::jit::CondValue174 CondValue(
175 Graph& g,
176 const SourceRange& loc,
177 bool static_value,
178 RefinementSet refinements)
179 : value_(g.insertConstant(static_value, loc)),
180 refinements_(std::move(refinements)),
181 static_if_(static_value) {}
valuetorch::jit::CondValue182 Value* value() const {
183 return value_;
184 }
refinementstorch::jit::CondValue185 const RefinementSet& refinements() const {
186 return refinements_;
187 }
staticIftorch::jit::CondValue188 std::optional<bool> staticIf() const {
189 return static_if_;
190 }
191
192 private:
193 Value* value_;
194 RefinementSet refinements_;
195 std::optional<bool>
196 static_if_; // certain expression cause us to emit a static if statement
197 // this value is present if this is the case.
198 // this is not equivalent to value_ being a constant
199 // it is possible for value_ to be constant but for
200 // the expression that produced it to not trigger the
201 // static if behavior. e.g. use of a variable assigned
202 // to a constant
203 };
204
205 enum NoneStatus { ALWAYS, MAYBE, NEVER };
canBeNone(Value * v)206 static NoneStatus canBeNone(Value* v) {
207 if (v->node()->mustBeNone()) {
208 return ALWAYS;
209 }
210 if (v->type()->kind() == OptionalType::Kind ||
211 (v->type()->kind() == UnionType::Kind &&
212 v->type()->expect<UnionType>()->canHoldType(*NoneType::get()))) {
213 return MAYBE;
214 }
215 return NEVER;
216 }
217
asSimple(const SugaredValuePtr & value)218 static Value* asSimple(const SugaredValuePtr& value) {
219 if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
220 return sv->getValue();
221 }
222 return nullptr;
223 }
224
// Convenience factory: wraps `base` in a MagicMethod sugared value keyed by
// the dunder name `name` (e.g. "__float__"). Presumably MagicMethod first
// tries the named method on the operand before falling back to `base` --
// see the MagicMethod definition to confirm.
static std::shared_ptr<MagicMethod> makeMagic(
    const std::string& name,
    const SugaredValuePtr& base) {
  return std::make_shared<MagicMethod>(name, base);
}
230
231 // Auxiliary data structure for desugaring variable binding into our always
232 // explicitly scoped language as we descend down nested control structures in
233 // the frontend (which themselves don't introduce scopes)
234 //
235 // The Environment keeps track of two tables, one for values which are not first
236 // class and a type table for values which are. When a first class value
237 // is set in the environment, we emit a prim::Store which sets the
238 // name of the variable to appropriate type, and when a first-class value is
239 // referenced we emit a prim::Load that generates a value of the appropriate
240 // type.
241 //
242 // a = 1
243 // print(a)
244 // becomes:
245 // = prim::Store[name="a"](%a.1)
246 // %a : int = prim::Load[name="a"]()
247 // prim::Print(%a)
248
249 struct Environment {
Environmenttorch::jit::Environment250 Environment(
251 GraphFunction& method,
252 ResolverPtr resolver,
253 Block* b,
254 std::shared_ptr<Environment> next = nullptr)
255 : method(method),
256 resolver(std::move(resolver)),
257 b(b),
258 next(std::move(next)) {}
259
260 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
261 GraphFunction& method;
262 ResolverPtr resolver;
263 std::unordered_map<std::string, std::function<std::string()>> error_messages;
264 Block* b;
265
266 std::shared_ptr<Environment> next;
267
268 // set type error in the lowest environment. if the variable is used after an
269 // error has been set, then we will use the more informative error message
setVariableTypeErrortorch::jit::Environment270 void setVariableTypeError(
271 const std::string& name,
272 std::function<std::string()> msg) {
273 auto runner = this;
274 while (runner->next) {
275 runner = runner->next.get();
276 }
277 runner->error_messages[name] = std::move(msg);
278 }
279
280 // see if type error has been set for a variable
findVariableTypeErrortorch::jit::Environment281 std::optional<std::string> findVariableTypeError(const std::string& name) {
282 auto runner = this;
283 while (runner->next) {
284 runner = runner->next.get();
285 }
286 auto msg = runner->error_messages.find(name);
287 if (msg != runner->error_messages.end()) {
288 return msg->second();
289 } else {
290 return std::nullopt;
291 }
292 }
293
insertLoadtorch::jit::Environment294 SugaredValuePtr insertLoad(const std::string& name, const TypePtr& type) {
295 auto g = b->owningGraph();
296 auto load = g->insertNode(g->createLoad(name, type));
297 if (meaningfulName(name)) {
298 load->output()->setDebugName(name);
299 }
300 return std::make_shared<SimpleValue>(load->output());
301 }
302
303 // note: type is not always the same as v->type(), e.g.
304 // type: Optional[Tensor]
305 // v->type(): Tensor
insertStoretorch::jit::Environment306 void insertStore(
307 const std::string& name,
308 const SourceRange& loc,
309 Value* v,
310 TypePtr type) {
311 auto g = b->owningGraph();
312 g->insertNode(g->createStore(name, v))->setSourceRange(loc);
313 type_table[name] = std::move(type);
314 }
315
findInThisFrametorch::jit::Environment316 SugaredValuePtr findInThisFrame(const std::string& name) {
317 auto it = value_table.find(name);
318 if (it != value_table.end()) {
319 return it->second;
320 }
321 auto it2 = type_table.find(name);
322 if (it2 != type_table.end()) {
323 return insertLoad(name, it2->second);
324 }
325 return nullptr;
326 }
327
findInParentFrametorch::jit::Environment328 SugaredValuePtr findInParentFrame(const std::string& name) {
329 return next ? next->findInAnyFrame(name) : nullptr;
330 }
331
setTypetorch::jit::Environment332 void setType(const std::string& name, TypePtr type) {
333 type_table[name] = std::move(type);
334 }
335
findInAnyFrametorch::jit::Environment336 SugaredValuePtr findInAnyFrame(const std::string& name) {
337 for (auto runner = this; runner; runner = runner->next.get()) {
338 if (auto r = runner->findInThisFrame(name)) {
339 return r;
340 }
341 }
342 return nullptr;
343 }
344
blocktorch::jit::Environment345 Block* block() {
346 return b;
347 }
348
setVartorch::jit::Environment349 void setVar(const SourceRange& loc, const std::string& name, Value* value) {
350 setSugaredVar(
351 loc,
352 name,
353 std::make_shared<SimpleValue>(value),
354 /*annotated_type=*/nullptr);
355 }
356
setSugaredVartorch::jit::Environment357 void setSugaredVar(
358 const SourceRange& loc,
359 const std::string& name,
360 SugaredValuePtr value,
361 const TypePtr& annotated_type) {
362 Value* as_simple_value = asSimple(value);
363 if (as_simple_value && !as_simple_value->hasDebugName() &&
364 meaningfulName(name) &&
365 // note: if the value wasn't defined in this block, we might be giving a
366 // name only used inside this block to a value outside of this. this is
367 // not normally helpful for debugging and causes import/export jitter.
368 as_simple_value->node()->owningBlock() == block()) {
369 as_simple_value->setDebugName(name);
370 }
371 // prevent re-assignment involving any sugared values
372 // any reassignment like:
373 // a = ...
374 // while ...
375 // a = ..
376 // requires 'a' to be first-class in the graph since its value depends on
377 // control flow
378 if (auto parent = findInParentFrame(name)) {
379 if (annotated_type) {
380 throw(
381 ErrorReport(loc)
382 << "Attempting to declare and annotate the type of variable '"
383 << name << "' but it is already defined in an outer block");
384 }
385 if (!as_simple_value) {
386 throw(
387 ErrorReport(loc)
388 << "Cannot re-assign '" << name << "' to a value of type "
389 << value->kind() << " because " << name
390 << " is not a first-class value. Only reassignments to first-class values are allowed");
391 }
392 Value* simple_parent = asSimple(parent);
393 if (!simple_parent) {
394 throw(
395 ErrorReport(loc)
396 << "Cannot re-assign '" << name << "' because it has type "
397 << value->kind() << " and " << name
398 << " is not a first-class value. Only reassignments to first-class values are allowed");
399 }
400
401 auto parent_type = unshapedType(simple_parent->type());
402 as_simple_value = tryConvertToType(
403 loc,
404 *b->owningGraph(),
405 parent_type,
406 as_simple_value,
407 /*allow_conversions=*/true);
408 std::stringstream why_not;
409 if (!as_simple_value->type()->isSubtypeOfExt(*parent_type, &why_not)) {
410 auto error = ErrorReport(loc);
411 error << "Variable '" << name << "' previously had type "
412 << simple_parent->type()->repr_str()
413 << " but is now being assigned to a value of type "
414 << as_simple_value->type()->repr_str();
415
416 // Special-cased error msg if we're trying to assign to a tensor list.
417 if (simple_parent->type()->kind() == TypeKind::ListType &&
418 as_simple_value->type()->kind() == TypeKind::ListType) {
419 error << "\nEmpty lists default to List[Tensor]. Add a variable "
420 "annotation to the assignment to create an empty list "
421 "of another type (torch.jit.annotate(List[T, []]) where T "
422 "is the type of elements in the list for Python 2)";
423 }
424 error << "\n" << why_not.str();
425 throw ErrorReport(error);
426 }
427 }
428 if (as_simple_value) {
429 if (annotated_type &&
430 !as_simple_value->type()->isSubtypeOf(*annotated_type)) {
431 throw(
432 ErrorReport(loc)
433 << "Variable '" << name << "' is annotated with type "
434 << annotated_type->repr_str()
435 << " but is being assigned to a value of type "
436 << as_simple_value->type()->repr_str());
437 }
438 auto value_store_type =
439 annotated_type ? annotated_type : as_simple_value->type();
440 insertStore(name, loc, as_simple_value, value_store_type);
441 } else {
442 value_table[name] = std::move(value);
443 }
444 }
445
getSugaredVartorch::jit::Environment446 SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
447 return getSugaredVar(ident.name(), ident.range());
448 }
getVartorch::jit::Environment449 Value* getVar(const Ident& ident) {
450 return getSugaredVar(ident)->asValue(ident.range(), method);
451 }
452
throwVarNotFoundErrortorch::jit::Environment453 void throwVarNotFoundError(
454 const std::string& ident,
455 const SourceRange& range) {
456 // check if this value was not emitted in an if statement because of a
457 // type mismatch. if it was, then we print a more informative error msg
458 if (auto msg = findVariableTypeError(ident)) {
459 throw(ErrorReport(range) << *msg << "and was used here");
460 }
461 throw(ErrorReport(range) << "undefined value " << ident);
462 }
463
getSugaredVartorch::jit::Environment464 SugaredValuePtr getSugaredVar(
465 const std::string& ident,
466 const SourceRange& range,
467 bool required = true) {
468 auto retval = findInAnyFrame(ident);
469
470 if (!retval) {
471 static std::unordered_map<std::string, SugaredValuePtr> globals = {
472 {"print", std::make_shared<PrintValue>()},
473 {"tuple", SpecialFormValue::create(prim::TupleConstruct)},
474 {"float",
475 makeMagic(
476 "__float__",
477 std::make_shared<CastValue>(FloatType::get(), aten::Float))},
478 {"complex",
479 makeMagic(
480 "__complex__",
481 std::make_shared<CastValue>(ComplexType::get(), aten::Complex))},
482 {"int",
483 makeMagic(
484 "__int__",
485 std::make_shared<CastValue>(IntType::get(), aten::Int))},
486 {"bool",
487 makeMagic(
488 "__bool__",
489 std::make_shared<CastValue>(BoolType::get(), aten::Bool))},
490 {"str",
491 makeMagic(
492 "__str__",
493 std::make_shared<CastValue>(StringType::get(), aten::str))},
494 {"getattr", SpecialFormValue::create(prim::GetAttr)},
495 {"hasattr", SpecialFormValue::create(prim::HasAttr)},
496 {"isinstance", SpecialFormValue::create(prim::isinstance)},
497 // todo(zach): remove when we can correctly export torch.full via ONNX
498 // or we have implicit conversion that can convert numbers to tensors
499 {"_to_tensor",
500 std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
501 {"len",
502 makeMagic(
503 "__len__",
504 std::make_shared<BuiltinFunction>(aten::len, std::nullopt))},
505 {"hex",
506 makeMagic(
507 "__hex__",
508 std::make_shared<BuiltinFunction>(aten::hex, std::nullopt))},
509 {"oct",
510 makeMagic(
511 "__oct__",
512 std::make_shared<BuiltinFunction>(aten::oct, std::nullopt))},
513 {"round",
514 makeMagic(
515 "__round__",
516 std::make_shared<BuiltinFunction>(aten::round, std::nullopt))},
517 {"hash", std::make_shared<BuiltinFunction>(aten::hash, std::nullopt)},
518 {"id", std::make_shared<BuiltinFunction>(prim::id, std::nullopt)},
519 {"min", std::make_shared<BuiltinFunction>(prim::min, std::nullopt)},
520 {"max", std::make_shared<BuiltinFunction>(prim::max, std::nullopt)},
521 {"abs", std::make_shared<BuiltinFunction>(prim::abs, std::nullopt)},
522 {"all", std::make_shared<BuiltinFunction>(aten::all, std::nullopt)},
523 {"any", std::make_shared<BuiltinFunction>(aten::any, std::nullopt)},
524 {"divmod",
525 std::make_shared<BuiltinFunction>(aten::divmod, std::nullopt)},
526 {"sum", std::make_shared<BuiltinFunction>(aten::sum, std::nullopt)},
527 {"list", SpecialFormValue::create(prim::list)},
528 {"dict", SpecialFormValue::create(prim::dict)},
529 {"ord", std::make_shared<BuiltinFunction>(aten::ord, std::nullopt)},
530 {"chr", std::make_shared<BuiltinFunction>(aten::chr, std::nullopt)},
531 {"bin", std::make_shared<BuiltinFunction>(aten::bin, std::nullopt)},
532 {"pow", std::make_shared<BuiltinFunction>(aten::pow, std::nullopt)},
533 {"range", SpecialFormValue::create(prim::range)},
534 {"zip", SpecialFormValue::create(prim::zip)},
535 {"enumerate", SpecialFormValue::create(prim::enumerate)},
536 {"rangelist",
537 std::make_shared<BuiltinFunction>(prim::rangelist, std::nullopt)},
538 {"sorted",
539 std::make_shared<BuiltinFunction>(aten::sorted, std::nullopt)},
540 // Only AssertionError is bound so that we can use it from emitAssert,
541 // all other exceptions should be resolved at the Python level
542 {"AssertionError",
543 std::make_shared<ExceptionValue>("AssertionError")},
544 };
545 auto it = globals.find(ident);
546 if (it != globals.end()) {
547 retval = it->second;
548 }
549 }
550
551 if (!retval) {
552 if (auto type = resolver->resolveType(ident, range)) {
553 if (auto tuple_type = type->cast<TupleType>()) {
554 retval = std::make_shared<NamedTupleConstructor>(tuple_type);
555 }
556 }
557 }
558
559 if (!retval) {
560 retval = resolver->resolveValue(ident, method, range);
561 }
562
563 if (!retval) {
564 if (auto type = resolver->resolveType(ident, range)) {
565 if (auto class_type = type->cast<ClassType>()) {
566 retval = std::make_shared<ClassValue>(class_type);
567 }
568 }
569 }
570
571 if (!retval && required) {
572 throwVarNotFoundError(ident, range);
573 }
574
575 return retval;
576 }
577
getVartorch::jit::Environment578 Value* getVar(const std::string& ident, const SourceRange& range) {
579 return getSugaredVar(ident, range)->asValue(range, method);
580 }
581
removeVartorch::jit::Environment582 void removeVar(const Ident& ident, bool check_if_removed = false) {
583 bool removed = false;
584
585 for (auto runner = this; runner; runner = runner->next.get()) {
586 auto a = runner->value_table.erase(ident.name());
587 auto b = runner->type_table.erase(ident.name());
588 removed = a || b;
589 }
590
591 if (check_if_removed && !removed) {
592 throwVarNotFoundError(ident.name(), ident.range());
593 }
594 }
595
definedVariablestorch::jit::Environment596 std::vector<std::string> definedVariables() {
597 std::vector<std::string> result;
598 for (auto& kv : type_table) {
599 result.push_back(kv.first);
600 }
601 return result;
602 }
603
604 private:
605 TypeTable type_table;
606 ValueTable value_table;
607 };
608
609 template <class T, class Hash>
materializeConstant(T val,Graph & graph,const SourceRange & r,std::unordered_map<T,Value *,Hash> & map)610 static Value* materializeConstant(
611 T val,
612 Graph& graph,
613 const SourceRange& r,
614 std::unordered_map<T, Value*, Hash>& map) {
615 auto existing_constant = map.find(val);
616 if (existing_constant != map.end()) {
617 return existing_constant->second;
618 }
619
620 WithInsertPoint guard(graph.block()->nodes().front());
621 auto new_constant = graph.insertConstant(val, r);
622 map[val] = new_constant;
623
624 return new_constant;
625 }
626
isSupportedListElementType(const TypePtr & type)627 inline bool isSupportedListElementType(const TypePtr& type) {
628 return type->isSubtypeOf(*TensorType::get()) ||
629 type->isSubtypeOf(*NumberType::get());
630 }
631
// Information for each def being emitted.
// Defs can be nested to support closures so we need a stack of this information
// Currently records information about the functions return type.
struct DefContext {
  TypePtr declared_return_type_; // nullptr if not annotated
  TypePtr merged_return_type_; // nullptr if a Return has not been seen yet;
                               // otherwise the merge of all seen return types
};
639
// Whether the code currently being emitted sits inside a loop body, and if
// so whether that loop is being unrolled at compile time.
enum class LoopStatus { NOT_IN_LOOP, IN_LOOP, IN_UNROLLED_LOOP };

// RAII guard: writes `next` into *slot on construction and restores the
// previous value on destruction.
struct WithLoopStatus {
  WithLoopStatus(LoopStatus* slot, LoopStatus next)
      : slot_(slot), saved_(*slot) {
    *slot_ = next;
  }
  ~WithLoopStatus() {
    *slot_ = saved_;
  }

 private:
  LoopStatus* slot_;
  LoopStatus saved_;
};
655
656 struct to_ir {
  // Emits the IR for `def` into `method`'s graph, then runs the fixed
  // sequence of cleanup/normalization passes. The pass order below is
  // load-bearing -- see the inline comments.
  to_ir(
      const Def& def,
      ResolverPtr resolver_,
      const Self* self,
      GraphFunction& method) // method being constructed
      : method(method),
        graph(method.graph()),
        resolver(std::move(resolver_)),
        typeParser_(resolver),
        environment_stack(nullptr) {
    AT_ASSERT(resolver);
    // Root frame for the function body; also opens the first DefContext.
    pushFrame(graph->block(), /*starts_def=*/true);

    // Type annotations exclude explicitly typing the "self" parameter, so in
    // the case that this is a method with self we expect one fewer parameter
    // annotation than the number of parameters this Def takes.
    if (self && def.decl().params().empty()) {
      throw(
          ErrorReport(def.decl().params().range())
          << "methods must have a self argument");
    }
    method.setSchema(emitDef(def, self, graph->block()));

    // At this point, we might have received a graph that is compiled with
    // old operator schemas that might not exist in the system anymore.
    // Therefore, we replace such ops with its' valid upgrader.
    ReplaceOldOperatorsWithUpgraders(graph);

    // NB ORDERING: SSA conversion has to occur before
    // lifting of closures and forks, this way closures are converted
    // to SSA while part of their original graph, and closures are ready to
    // be inlined into forked closures
    ConvertToSSA(graph);

    // convert loops with an iter and body condition specified to
    // python-recognize while loops. we do this so they can be exported,
    // and run the pass early to avoid jitter. Like conversion to SSA,
    // it only needs to run once.
    CanonicalizeModifiedLoops(graph);

    // Convert Ops to a Normalized Form
    NormalizeOps(graph);

    runCleanupPasses(graph);
  }
702
703 private:
704 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
705 GraphFunction& method;
706 std::shared_ptr<Graph> graph;
707 ResolverPtr resolver;
708 std::unordered_map<int64_t, Value*, std::hash<int64_t>> integral_constants;
709 std::unordered_map<double, Value*, std::hash<double>> fp_constants;
710 std::unordered_map<
711 c10::complex<double>,
712 Value*,
713 c10::hash<c10::complex<double>>>
714 complex_constants;
715 std::unordered_set<Block*> exit_blocks;
716 ScriptTypeParser typeParser_;
717 LoopStatus loop_status_ = LoopStatus::NOT_IN_LOOP;
718
719 // Singly-linked list of environments. This top element contains a member
720 // `next` that points to the most immediate enclosing scope's value.
721 std::shared_ptr<Environment> environment_stack;
722 std::vector<DefContext> def_stack_;
723 size_t temp_name_count_ = 0;
createTempNametorch::jit::to_ir724 std::string createTempName(const std::string& prefix) {
725 return prefix + std::to_string(temp_name_count_++);
726 }
727
pushFrametorch::jit::to_ir728 void pushFrame(Block* b, bool starts_def = false) {
729 if (starts_def) {
730 def_stack_.emplace_back();
731 }
732 environment_stack =
733 std::make_shared<Environment>(method, resolver, b, environment_stack);
734 }
popFrametorch::jit::to_ir735 std::shared_ptr<Environment> popFrame(bool ends_def = false) {
736 auto old_frame = environment_stack;
737 environment_stack = environment_stack->next;
738 if (ends_def) {
739 def_stack_.pop_back();
740 }
741 return old_frame;
742 }
743
744 // If the graph might not return, add an implicit None return at the end
handleMaybeNoReturntorch::jit::to_ir745 void handleMaybeNoReturn(const Def& def, Block* block) {
746 auto decl_ret = def_stack_.back().declared_return_type_;
747 if (exit_blocks.count(block) == 0) {
748 auto decl_ret = def_stack_.back().declared_return_type_;
749 if (decl_ret && decl_ret != NoneType::get()) {
750 throw(
751 ErrorReport(def.range())
752 << "Function was not annotated as having type None, but does not "
753 << "return along all paths");
754 }
755 WithInsertPoint b(*block->nodes().end());
756 emitReturn(Return::create(
757 def.range(), Expr(Compound::create(TK_NONE, def.range(), {}))));
758 } else {
759 // if we haven't seen any return statements, but the graph block exits
760 // (the function always throws) then we accept the declared return type if
761 // it exists or set it to none
762 if (def_stack_.back().merged_return_type_ == nullptr) {
763 def_stack_.back().merged_return_type_ =
764 decl_ret != nullptr ? decl_ret : NoneType::get();
765 }
766 }
767 }
768
emitDeftorch::jit::to_ir769 FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
770 auto schema = typeParser_.parseSchemaFromDef(def, bool(self));
771 // TODO need guards on init returning none
772 if (schema.returns().size() == 1) {
773 def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
774 }
775 std::vector<Argument> arguments =
776 emitFormalArguments(def, self, schema, block);
777
778 // body
779 auto stmts_list = def.statements();
780 emitStatements(stmts_list.begin(), stmts_list.end());
781 handleMaybeNoReturn(def, block);
782 std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
783 return {def.name().name(), "", std::move(arguments), std::move(returns)};
784 }
785
786 // see [setstate type]
getTypeForSetStateArgtorch::jit::to_ir787 static TypePtr getTypeForSetStateArg(const Def& def, const Self* self) {
788 TORCH_CHECK(self, "Expected __setstate__ to have a `self` argument");
789 auto getstate = self->getClassType()->findMethod("__getstate__");
790 if (!getstate) {
791 throw(
792 ErrorReport(def.range())
793 << "`__setstate__` defined but not `__getstate__`. "
794 << "You must have both defined on a ScriptModule "
795 << "to customize serialization.\n"
796 << "Did you forget to use `@torch.jit.export`?");
797 }
798 getstate->ensure_defined();
799 return self->getClassType()
800 ->getMethod("__getstate__")
801 .getSchema()
802 .returns()
803 .at(0)
804 .type();
805 }
806
807 // see [setstate type]
shouldDeriveSetStateTypetorch::jit::to_ir808 static bool shouldDeriveSetStateType(
809 const Def& def,
810 const FunctionSchema& schema) {
811 const bool noTypeAnnotations = std::all_of(
812 schema.arguments().begin(),
813 schema.arguments().end(),
814 [](const Argument& arg) { return arg.is_inferred_type(); });
815
816 bool shouldInfer = def.name().name() == "__setstate__" && noTypeAnnotations;
817 if (!shouldInfer) {
818 return false;
819 }
820
821 // Do some additional basic validation that the __setstate__ func is
822 // well-formed
823 TORCH_INTERNAL_ASSERT(def.name().name() == "__setstate__");
824 const auto numDeclParams = def.decl().params().size();
825 if (numDeclParams != 2) {
826 throw(
827 ErrorReport(def.range())
828 << "Expected 2 arguments for `__setstate__`, got: " << numDeclParams);
829 }
830 return true;
831 }
832
  // Binds the def's formal parameters as graph inputs and environment
  // variables, returning the schema Argument list. Validates that the number
  // of type annotations matches the parameter count (minus `self`).
  std::vector<Argument> emitFormalArguments(
      const Def& def,
      const Self* self,
      const FunctionSchema& schema,
      Block* block) {
    std::vector<Argument> arguments; // for schema
    // inputs
    auto it = def.decl().params().begin();
    auto end = def.decl().params().end();
    auto expected_annotation_size = def.decl().params().size();
    if (self) {
      // `self` is never annotated; see the constructor comment.
      expected_annotation_size--;
    }
    if (schema.arguments().size() != expected_annotation_size) {
      throw(
          ErrorReport(def.decl().params().range())
          << "Number of type annotations for"
          << " function parameters (" << schema.arguments().size() << ")"
          << " does not match the number of parameters on the function ("
          << expected_annotation_size << ")!");
    }

    if (self) {
      // `self` becomes the first graph input, bound through the Self's own
      // sugaring rather than as a plain SimpleValue.
      AT_ASSERT(it != end);
      const auto& name = (*it).ident().name();
      Value* new_input = block->addInput()->setDebugName(name);
      environment_stack->setSugaredVar(
          (*it).ident().range(),
          name,
          self->makeSugared(new_input),
          /*annotated_type=*/nullptr);
      arguments.emplace_back(name, new_input->type());
      ++it;
    }

    // [setstate type]
    // __setstate__ is special, because if the user leaves it un-annotated we
    // will derive the type for `state` from the output type of __getstate__.
    // This is necessary so that we can allow submodules to appear in `state`.
    bool shouldDeriveType = shouldDeriveSetStateType(def, schema);
    size_t arg_annotation_idx = 0;
    for (; it != end; ++it) {
      auto& name = (*it).ident().name();
      // Add the input to the graph
      Value* new_input = block->addInput();
      if (meaningfulName(name)) {
        new_input->setDebugName(name);
      }
      // Record the type for the schema and set the Type on the Value*
      auto arg = schema.arguments().at(arg_annotation_idx++);
      if (shouldDeriveType) {
        TORCH_INTERNAL_ASSERT(schema.arguments().size() == 1);
        const auto& inferredStateType = getTypeForSetStateArg(def, self);
        arg = arg.cloneWithType(inferredStateType);
      }

      arguments.push_back(arg);
      new_input->setType(arguments.back().type());

      // NB: set type of new_input before setVar call so the Store is
      // typed appropriately
      environment_stack->setVar((*it).ident().range(), name, new_input);
    }
    return arguments;
  }
898
  // Register a placeholder return value on `block` and return the schema
  // Argument describing the function's output.
  // NOTE(review): `range` and `schema` are currently unused in this body.
  Argument emitOutput(
      const SourceRange& range,
      const FunctionSchema& schema,
      Block* block) {
    // handleMaybeNoReturn ensures that merged_return_type_ is always set
    auto ret_type = def_stack_.back().merged_return_type_;
    TORCH_INTERNAL_ASSERT(ret_type);

    // in the ConvertToSSA pass, prim::ReturnStmts are lowered so that the
    // correct return value is set. Until then, we have a correctly-typed
    // placeholder return value. This is needed so that closures & graphs
    // are correctly typed.
    auto placeholder_return =
        graph->insertNode(graph->createUninitialized(ret_type))->output();
    block->registerOutput(placeholder_return);
    return Argument("", def_stack_.back().merged_return_type_);
  }
916
emitStatementstorch::jit::to_ir917 void emitStatements(const List<Stmt>& statements) {
918 return emitStatements(statements.begin(), statements.end());
919 }
920
  // XXX: Right now closures are not generically implemented and are only used
  // as an intermediate form for special tasks, like defining gradients or
  // forked functions.
  //
  // There are several unfinished aspects that make them unusable generally
  // 1. We do not have a type, ivalue, operator to represent prim::Closure, so
  //    closure_node has type None
  // 2. There is no export logic for it yet, so it cannot be
  //    exported/python_printed
  // 3. There is nothing preventing the assignment of already existing variables
  //    inside the closures
  //    the changes to those variables will just get forgotten.
  // 4. There is no parsing support in frontend.py, this is intentional since it
  //    prevents people from accidentally using this feature.
  //
  // This function leaves in the graph something like:
  //
  //   %2 : None = prim::Closure()
  //     block0():
  //       %1 : Tensor = prim::DoSomething(%0)
  //       -> (%1)
  //
  // A separate pass is required to erase this closure and replace it with
  // something actually executable (see liftClosure and inlineForkedClosure).
  //
  // `emit_body` is invoked with the closure's block as the insertion point
  // and a fresh environment frame pushed; the frame is popped afterwards.
  std::shared_ptr<ClosureValue> emitClosure(
      const std::function<void(Block*)>& emit_body) {
    Node* closure_node = graph->insertNode(graph->create(prim::Closure, 1));
    // it is not a real thing yet, so just say the type is None
    closure_node->output()->setType(NoneType::get());
    Block* block = closure_node->addBlock();
    // break/continue inside the closure must not bind to an enclosing loop
    WithLoopStatus loop_guard(&loop_status_, LoopStatus::NOT_IN_LOOP);
    {
      WithInsertPoint guard(block);
      pushFrame(block, /*starts_def=*/true);
      emit_body(block);
      popFrame(/*ends_def=*/true);
    }
    return std::make_shared<ClosureValue>(closure_node->output());
  }
960
emitClosuretorch::jit::to_ir961 void emitClosure(const Def& def) {
962 // invoked once the closure block is set as the environment
963 auto emit_body = [&](Block* closure_block) {
964 emitDef(
965 def,
966 nullptr,
967 closure_block); // ignore schema return, we just wont use it for now
968 // since we never create a Method for the closure
969 };
970 auto closure_value = emitClosure(emit_body);
971 environment_stack->setSugaredVar(
972 def.name().range(),
973 def.name().name(),
974 closure_value,
975 /*annotated_type=*/nullptr);
976 }
977
checkBreakContinuetorch::jit::to_ir978 void checkBreakContinue(
979 const SourceRange& loc,
980 const std::string& stmt_name) {
981 if (loop_status_ == LoopStatus::NOT_IN_LOOP) {
982 throw(
983 ErrorReport(loc) << "SyntaxError: '" << stmt_name << "'"
984 << " outside loop");
985 } else if (loop_status_ == LoopStatus::IN_UNROLLED_LOOP) {
986 throw(
987 ErrorReport(loc)
988 << "Because we emit iteration over modulelists or tuples as "
989 "unrolled loops, we do not support break or continue inside the body of these loops");
990 }
991 }
992
emitBreaktorch::jit::to_ir993 void emitBreak(const Break& stmt) {
994 checkBreakContinue(stmt.range(), "break");
995 auto break_node =
996 graph->create(prim::BreakStmt, {}, 0)->setSourceRange(stmt.range());
997 graph->insertNode(break_node);
998 }
999
emitContinuetorch::jit::to_ir1000 void emitContinue(const Continue& stmt) {
1001 checkBreakContinue(stmt.range(), "continue");
1002 auto continue_node =
1003 graph->create(prim::ContinueStmt, {}, 0)->setSourceRange(stmt.range());
1004 graph->insertNode(continue_node);
1005 }
1006
  // Emit a `del` statement. Supported targets:
  //  - a single subscript (`del x[i]`): dispatched to the class's
  //    `__delitem__` for class instances, otherwise lowered to aten::Delete
  //  - a plain variable (`del x`): removed from the current environment
  // Slice targets and any other target form raise an error.
  void emitDelete(const Delete& stmt) {
    for (const auto& target : stmt.targets()) {
      if (target.kind() == TK_SUBSCRIPT) {
        Subscript subscript(target);
        const List<Expr>& subscript_exprs = subscript.subscript_exprs();
        if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
          throw(
              ErrorReport(target.range())
              << "del statements only support deletion at a single index, "
                 "slicing is not supported"
                 " (see https://github.com/pytorch/pytorch/issues/31430)");
        }
        const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1);
        const SourceRange& val_range = subscript.value().range();
        Value* idx = emitExpr(subscript_exprs[0]);
        Value* val = sv->asValue(val_range, method);

        // If val is a class instance, this is a method call to a type-specific
        // implementation of del defined in a __delitem__ method.
        if (auto cls = val->type()->cast<ClassType>()) {
          if (!cls->findMethod("__delitem__")) {
            throw(
                ErrorReport(target.range())
                << "Class does not define __delitem__");
          }

          // Use MethodValue to call the method to handle recursion.
          MethodValue(val, "__delitem__")
              .call(stmt.range(), method, {idx}, {}, 0);
        } else {
          // built-in containers (list/dict) get a dedicated Delete op
          auto node = graph->create(aten::Delete, {val, idx}, 0)
                          ->setSourceRange(target.range());
          graph->insertNode(node);
        }
      } else if (target.kind() == TK_VAR) {
        Var var(target);
        environment_stack->removeVar(var.name(), /*check_if_removed=*/true);
      } else {
        throw(
            ErrorReport(target.range())
            << "del statements are only supported for deleting"
               " list and dict items and variables");
      }
    }
  }
1052
emitReturntorch::jit::to_ir1053 void emitReturn(const Return& stmt) {
1054 TypePtr declared_return_type =
1055 def_stack_.back().declared_return_type_; // nullptr if not annotated
1056 auto actual_return = emitExpr(stmt.expr(), declared_return_type);
1057
1058 // result type is annotated, every return must convert to that type
1059 if (declared_return_type) {
1060 // this guard skips implicit conversion from None -> Tensor for the return
1061 // type. otherwise forgetting a return a function returning a tensor will
1062 // cause a None to be converted to a tensor.
1063 if (!(actual_return->type()->isSubtypeOf(*TensorType::get()) &&
1064 actual_return->type()->isSubtypeOf(*NoneType::get()))) {
1065 actual_return = tryConvertToType(
1066 stmt.range(),
1067 *graph,
1068 declared_return_type,
1069 actual_return,
1070 /*allow_conversions=*/true);
1071 }
1072 if (!actual_return->type()->isSubtypeOf(*declared_return_type)) {
1073 throw(
1074 ErrorReport(stmt.range())
1075 << "Return value was annotated as having type "
1076 << declared_return_type->repr_str() << " but is actually of type "
1077 << actual_return->type()->repr_str());
1078 }
1079 } else {
1080 declared_return_type = def_stack_.back().merged_return_type_;
1081 if (!declared_return_type) {
1082 declared_return_type = actual_return->type();
1083 }
1084 auto merged_return_type =
1085 unifyTypes(declared_return_type, actual_return->type());
1086 if (!merged_return_type) {
1087 throw(
1088 ErrorReport(stmt.range())
1089 << "Previous return statement returned a value of type "
1090 << declared_return_type->repr_str()
1091 << " but this return statement returns a value of type "
1092 << actual_return->type()->repr_str());
1093 }
1094 declared_return_type = merged_return_type.value();
1095 }
1096 AT_ASSERT(declared_return_type);
1097
1098 def_stack_.back().merged_return_type_ = declared_return_type;
1099
1100 // If the annotated return type is Any and the result type is not Any,
1101 // cast the result to Any to facilitate type unification between return
1102 // statements on different code paths (e.g. different branches of an if,
1103 // body and containing scope of a loop).
1104 if (declared_return_type == AnyType::get() &&
1105 actual_return->type() != AnyType::get()) {
1106 actual_return =
1107 graph->insertUncheckedCast(actual_return, declared_return_type);
1108 }
1109
1110 graph->insertNode(graph->create(prim::ReturnStmt, {actual_return}, 0));
1111 exit_blocks.insert(environment_stack->block());
1112 }
1113
  // Dispatch each statement in [begin, end) to its dedicated emitter.
  // Emission stops early once a statement exits the current block
  // (return/break/continue/raise), since the remaining statements are
  // unreachable.
  void emitStatements(
      List<Stmt>::const_iterator begin,
      List<Stmt>::const_iterator end) {
    for (; begin != end; ++begin) {
      auto stmt = *begin;
      // keep error reporting anchored at the statement being compiled
      ErrorReport::CallStack::update_pending_range(stmt.range());
      switch (stmt.kind()) {
        case TK_IF:
          emitIf(If(stmt));
          break;
        case TK_WHILE:
          emitWhile(While(stmt));
          break;
        case TK_FOR:
          emitFor(For(stmt));
          break;
        case TK_ASSIGN:
          emitAssignment(Assign(stmt));
          break;
        case TK_AUG_ASSIGN:
          emitAugAssignment(AugAssign(stmt));
          break;
        case TK_EXPR_STMT: {
          auto expr = ExprStmt(stmt).expr();
          emitSugaredExpr(expr, 0);
        } break;
        case TK_RAISE:
          emitRaise(Raise(stmt));
          break;
        case TK_ASSERT:
          emitAssert(Assert(stmt));
          break;
        case TK_RETURN: {
          emitReturn(Return(stmt));
        } break;
        case TK_CONTINUE: {
          emitContinue(Continue(stmt));
        } break;
        case TK_BREAK: {
          emitBreak(Break(stmt));
        } break;
        case TK_PASS:
          // Emit nothing for pass
          break;
        case TK_DEF:
          // nested def: compiled as a closure
          emitClosure(Def(stmt));
          break;
        case TK_DELETE:
          emitDelete(Delete(stmt));
          break;
        case TK_WITH:
          emitWith(With(stmt));
          break;
        default:
          throw(
              ErrorReport(stmt)
              << "Unrecognized statement kind " << kindToString(stmt.kind()));
      }
      // Found an exit statement in this block. The remaining statements aren't
      // reachable so we don't emit them.
      if (exit_blocks.count(environment_stack->block()))
        return;
    }
  }
1178
  // Compute the type refinements implied by `lhs <is / is not> rhs` when one
  // side is the literal None and the other is a plain variable. For an
  // Optional- or Union-typed variable, `x is not None` refines `x` to the
  // non-None part of its type (and vice versa for the false branch).
  // Returns an empty set when no refinement applies.
  RefinementSet findIsNoneRefinements(
      const Expr& lhs,
      Value* lhs_value,
      const Expr& rhs,
      Value* rhs_value,
      int tok) {
    if (rhs.kind() != TK_NONE && lhs.kind() == TK_NONE) {
      // make 'None is var' into 'var is None'
      return findIsNoneRefinements(rhs, rhs_value, lhs, lhs_value, tok);
    }
    if (rhs.kind() != TK_NONE || lhs.kind() != TK_VAR) {
      return {};
    }
    // statement must be var {is, is not} None
    const std::string& name = Var(lhs).name().name();
    // While it should in theory be possible to specialize
    // the `x is None` to know x has type NoneType, we have previously
    // not done this. Unfortunately, doing this will make the type None
    // propagate further in all loaded models. The handling of
    // unwrap_optional will fail in these cases since export did
    // not expect that the input would be none and an unannotated None.
    // To enable this, we need to (1) implement a real casting operator
    // annotated(T, X) that stays in the graph and does the cast
    // and (2) only enable this OPTIONAL_NONE when loading newer
    // graphs because it is incompatible with older graphs.
    // Refinement none(name, RefinementKind::OPTIONAL_NONE);
    if (const auto optional_type = lhs_value->type()->cast<OptionalType>()) {
      // `x is not None` proves x has the Optional's element type
      Refinement present(name, optional_type->getElementType());
      if (tok == TK_IS) {
        return RefinementSet({}, {present});
      } else { // TK_ISNOT
        return RefinementSet({present}, {});
      }
    }
    if (const auto union_type = lhs_value->type()->cast<UnionType>()) {
      // `x is not None` proves x has the Union minus NoneType (if non-empty)
      std::vector<TypePtr> to_subtract{NoneType::get()};
      std::optional<TypePtr> remaining =
          union_type->subtractTypeSet(to_subtract);
      std::vector<Refinement> all_present;
      if (remaining) {
        Refinement present{name, *remaining};
        all_present.push_back(std::move(present));
      }
      if (tok == TK_IS) {
        return RefinementSet({}, all_present);
      } else { // TK_ISNOT
        return RefinementSet(all_present, {});
      }
    }
    return RefinementSet();
  }
1230
  // Emit `expr` as a boolean condition. The returned CondValue carries:
  //  - the Value* holding the runtime bool,
  //  - the type refinements implied by the condition (is/is not None,
  //    isinstance, hasattr), and
  //  - an optional statically-known truth value used for metacompilation.
  CondValue emitCondExpr(const Expr& expr) {
    switch (expr.kind()) {
      case TK_AND:
      case TK_OR: {
        auto binop = BinOp(expr);
        return emitShortCircuitLogical(
            binop.range(), binop.lhs(), binop.rhs(), expr.kind() == TK_OR);
      }
      case TK_NOT: {
        // negate both the runtime value, the refinements, and (if known)
        // the static truth value
        CondValue v = emitCondExpr(Expr(expr.tree()->trees()[0]));
        Value* result = emitBuiltinCall(
            expr.range(), *graph, aten::__not__, {v.value()}, {});
        std::optional<bool> static_if;
        if (v.staticIf()) {
          static_if = !*v.staticIf();
        }
        return CondValue(result, v.refinements().Not(), static_if);
      } break;
      case TK_IS:
      case TK_ISNOT: {
        // meta programming on AST for is/is not cases and emit branches base on
        auto cond_op = BinOp(expr);
        Value* lhs_val = emitExpr(cond_op.lhs());
        Value* rhs_val = emitExpr(cond_op.rhs());

        auto lhs_none = canBeNone(lhs_val);
        auto rhs_none = canBeNone(rhs_val);

        // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
        //
        // AA, -> statically IS always holds, IS_NOT never holds
        // AN , NA-> statically IS_NOT always holds, IS never holds
        // MA, MM, MN, NM, NN, AM -> cannot prove anything statically
        bool its_is = expr.kind() == TK_IS;
        if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
          return CondValue(*graph, expr.range(), its_is, {});
        } else if (
            (lhs_none == ALWAYS && rhs_none == NEVER) ||
            (lhs_none == NEVER && rhs_none == ALWAYS)) {
          // lhs_val/rhs_val with A/M: only emit never_none_branch
          return CondValue(*graph, expr.range(), !its_is, {});
        } else {
          // truth value unknown: emit the runtime comparison, plus any
          // Optional/Union refinements for `x is (not) None`
          auto kind = getNodeKind(expr.kind(), expr.get()->trees().size());
          Value* cond_value = emitBuiltinCall(
              expr.get()->range(),
              *method.graph(),
              kind,
              {lhs_val, rhs_val},
              {});
          auto refinements = RefinementSet(findIsNoneRefinements(
              cond_op.lhs(), lhs_val, cond_op.rhs(), rhs_val, expr.kind()));
          return CondValue(cond_value, refinements, std::nullopt);
        }
      } break;
      default: {
        // isinstance(...)/hasattr(...) calls are special-cased because
        // they can be resolved statically and yield refinements
        if (expr.kind() == TK_APPLY) {
          auto apply = Apply(expr);
          auto callee = Apply(expr).callee();
          if (callee.kind() == TK_VAR) {
            if (Var(callee).name().name() == "isinstance") {
              checkApplyNumInputs(apply, 2);
              return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
            }
            if (Var(callee).name().name() == "hasattr") {
              checkApplyNumInputs(apply, 2);
              return emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
            }
          }
          // the callee may also resolve to the isinstance special form
          // (e.g. via an alias)
          auto sv = emitSugaredExpr(apply.callee(), 1);
          auto loc = apply.callee().range();
          if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
            if (special_form->form() == prim::isinstance) {
              checkApplyNumInputs(apply, 2);
              return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
            }
          }
        }
        // generic expression: coerce to bool and look for a static value
        auto expr_out = emitToBool(expr.range(), emitExpr(expr));
        std::optional<bool> static_if = std::nullopt;
        auto kind = expr_out->node()->kind();
        if (kind == aten::is_scripting) {
          static_if = true;
        } else if (kind == aten::has_torch_function) {
          static_if = false;
        }
        // MetaCompile on boolean literals and constants
        if (auto maybe_ivalue = toIValue(expr_out)) {
          static_if = maybe_ivalue->toBool();
        }
        return CondValue(expr_out, RefinementSet({}), static_if);
      } break;
    }
  }
1324
emitSingleIfBranchtorch::jit::to_ir1325 std::shared_ptr<Environment> emitSingleIfBranch(
1326 Block* b,
1327 const List<Stmt>& branch,
1328 const RefinementSet& refinements) {
1329 pushFrame(b);
1330 WithInsertPoint guard(b);
1331 insertRefinements(branch.range(), refinements);
1332 emitStatements(branch);
1333 return popFrame();
1334 }
1335
createtorch::jit::to_ir1336 Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
1337 return graph->create(kind, n_outputs)->setSourceRange(loc);
1338 }
1339
emitTernaryIftorch::jit::to_ir1340 Value* emitTernaryIf(
1341 const TernaryIf& expr,
1342 const TypePtr& type_hint = nullptr) {
1343 CondValue cond_value = emitCondExpr(expr.cond());
1344 // If the cond expr is a static value, then we metacompile the `if`
1345 // statemement and only emit true or false branch
1346 if (cond_value.staticIf()) {
1347 if (*cond_value.staticIf()) {
1348 return emitExpr(expr.true_expr(), type_hint);
1349 } else {
1350 return emitExpr(expr.false_expr(), type_hint);
1351 }
1352 }
1353 auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
1354 auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
1355 return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
1356 }
1357
  // Shared helper for container constructors/comprehensions given a type
  // hint. If the hint is a Union, collect its contained types that satisfy
  // `type_match` into `all_candidates` (refining directly when exactly one
  // matches); Optional hints are unwrapped to their element type. For any
  // other hint, validate it via `type_match` and invoke `do_if_match` /
  // `do_if_anytype`, throwing when nothing applies. `match_repr` ("List",
  // "Dict", ...) is used only in error messages; `is_dict_constructor`
  // relaxes the checks for the `dict(...)` call form.
  template <class F1, class F2, class F3>
  void refineAndSetUnionTypeHintOrPopulateCandidatesVector(
      const TypePtr& type_hint,
      TypePtr* refined_type_hint_ptr,
      std::vector<TypePtr>* all_candidates,
      const std::string& match_repr,
      const Expr& src,
      const F1& type_match,
      const F2& do_if_match,
      const F3& do_if_anytype,
      bool is_dict_constructor = false) {
    if (auto union_type_hint = (*refined_type_hint_ptr)->cast<UnionType>()) {
      // `candidate_types` holds all List types that were in the Union
      // annotation
      std::vector<TypePtr> candidate_types;

      std::copy_if(
          union_type_hint->containedTypes().begin(),
          union_type_hint->containedTypes().end(),
          std::back_inserter(candidate_types),
          [&](TypePtr type_ptr) { return type_match(type_ptr); });

      if (!is_dict_constructor && candidate_types.empty()) {
        throw(
            ErrorReport(src)
            << "Expected an Union type annotation "
            << "with an inner " << match_repr << " type, but got "
            << (*refined_type_hint_ptr)->repr_str());
      } else if (candidate_types.size() == 1) {
        // The Union only had a single type of the container we want to
        // match, so we can unconditionally refine it to that type
        (*refined_type_hint_ptr) = candidate_types[0];
      } else {
        // We can't refine the Union yet, since it contains multiple
        // types of the container we want to match, but we do at least
        // have a list of possible types (e.g. `Union[List[int],
        // List[str], float, str]` -> candidates={List[int], List[str]})
        (*all_candidates) = std::move(candidate_types);
      }
    } else if (
        auto optional_type_hint =
            (*refined_type_hint_ptr)->cast<OptionalType>()) {
      (*refined_type_hint_ptr) = optional_type_hint->getElementType();
    }

    // This case handles code like `dict([(x, y), (a, b)])` that would
    // otherwise fail the following error checks
    if (is_dict_constructor) {
      return;
    }

    // If we had any annotation that was NOT a Union that can hold more
    // than one type of the container we want to match
    if (all_candidates->empty()) {
      if (type_match(*refined_type_hint_ptr)) {
        do_if_match();
      } else if ((*refined_type_hint_ptr)->kind() == AnyType::Kind) {
        do_if_anytype();
      } else {
        throw(
            ErrorReport(src) << "Expected an annotation of type " << match_repr
                             << " but got " << type_hint->repr_str());
      }
    }
  }
1423
refineAndSetListTypeHintFromCandidatesVectortorch::jit::to_ir1424 void refineAndSetListTypeHintFromCandidatesVector(
1425 const std::vector<TypePtr>& all_candidates,
1426 const TypePtr& type_hint,
1427 TypePtr* refined_type_hint_ptr,
1428 const TypePtr& unified_elem_type,
1429 const Expr& src) {
1430 TypePtr greatest_elem_type = nullptr;
1431 std::for_each(
1432 all_candidates.begin(),
1433 all_candidates.end(),
1434 [&](const TypePtr& candidate) {
1435 auto candidate_elem_type =
1436 candidate->expect<ListType>()->getElementType();
1437 if (unified_elem_type->isSubtypeOf(candidate_elem_type)) {
1438 if (!greatest_elem_type) {
1439 greatest_elem_type = candidate_elem_type;
1440 } else {
1441 greatest_elem_type =
1442 *(unifyTypes(greatest_elem_type, candidate_elem_type));
1443 }
1444 }
1445 });
1446 if (!greatest_elem_type) {
1447 std::stringstream vector_repr;
1448 for (size_t i = 0; i < all_candidates.size(); ++i) {
1449 if (i > 0 && all_candidates.size() > 2) {
1450 vector_repr << ", ";
1451 }
1452 if (i != 0 && i == all_candidates.size() - 1) {
1453 vector_repr << " or ";
1454 }
1455 vector_repr << all_candidates[i]->repr_str();
1456 }
1457 throw(
1458 ErrorReport(src) << "Union type annotation `" << type_hint->repr_str()
1459 << "` can hold " << vector_repr.str()
1460 << ", but none of "
1461 << "those types match the types of the given list "
1462 << "elements, which were unified to "
1463 << unified_elem_type->repr_str());
1464 } else {
1465 (*refined_type_hint_ptr) = ListType::create(greatest_elem_type);
1466 ;
1467 }
1468 }
1469
refineAndSetDictTypeHintFromCandidatesVectortorch::jit::to_ir1470 void refineAndSetDictTypeHintFromCandidatesVector(
1471 const std::vector<TypePtr>& all_candidates,
1472 const TypePtr& type_hint,
1473 TypePtr* refined_type_hint_ptr,
1474 const TypePtr& known_key_type,
1475 const TypePtr& known_value_type,
1476 const Expr& src) {
1477 TypePtr candidate_key_type = nullptr;
1478 TypePtr candidate_value_type = nullptr;
1479 TypePtr candidate = nullptr;
1480
1481 for (const auto& current_candidate : all_candidates) {
1482 auto current_key_type =
1483 current_candidate->expect<DictType>()->getKeyType();
1484 auto current_value_type =
1485 current_candidate->expect<DictType>()->getValueType();
1486
1487 if (known_key_type->isSubtypeOf(current_key_type) &&
1488 known_value_type->isSubtypeOf(current_value_type)) {
1489 if (!candidate ||
1490 (candidate_key_type->isSubtypeOf(current_key_type) &&
1491 candidate_value_type->isSubtypeOf(current_value_type))) {
1492 candidate_key_type = current_key_type;
1493 candidate_value_type = current_value_type;
1494 candidate = current_candidate;
1495 }
1496 }
1497 }
1498
1499 if (!candidate) {
1500 std::stringstream vector_repr;
1501 for (size_t i = 0; i < all_candidates.size(); ++i) {
1502 if (i > 0 && all_candidates.size() > 2) {
1503 vector_repr << ", ";
1504 }
1505 if (i != 0 && i == all_candidates.size() - 1) {
1506 vector_repr << " or ";
1507 }
1508 vector_repr << all_candidates[i]->repr_str();
1509 }
1510 throw(
1511 ErrorReport(src) << "Union type annotation `" << type_hint->repr_str()
1512 << "` can hold " << vector_repr.str()
1513 << ", but none of "
1514 << "those dict types can hold the types of the given"
1515 << " keys and values, which were unified to Dict["
1516 << known_key_type->repr_str() << ", "
1517 << known_value_type->repr_str());
1518 } else {
1519 (*refined_type_hint_ptr) = candidate;
1520 }
1521 }
1522
  // Emit a list comprehension `[elt for target in iter]`. The list's element
  // type is driven by `type_hint` when given (including narrowing a Union of
  // several List types); otherwise it is unified from the generated elements.
  // The comprehension body is emitted in its own scope so its variables do
  // not leak into the surrounding graph.
  Value* emitListComprehension(const ListComp& lc, const TypePtr& type_hint) {
    const auto loc = lc.range();
    const auto targets_list = List<Expr>::create(lc.range(), {lc.target()});
    const auto itrs = List<Expr>::create(lc.range(), {lc.iter()});

    // If there is no type hint, and this is emitted over an iterable that is
    // unrolled and of length 0, then we emit a List of tensors
    Value* list_value = graph->insertNode(graph->create(prim::ListConstruct, 1))
                            ->output()
                            ->setType(ListType::ofTensors());

    TypePtr refined_type_hint = type_hint;
    std::vector<TypePtr> all_candidates = {};

    if (refined_type_hint) {
      auto do_if_type_match = [&]() { list_value->setType(refined_type_hint); };

      auto type_match = [&](const TypePtr& t) {
        return t->isSubtypeOf(AnyListType::get());
      };

      // narrow a Union hint to its List members (or error out)
      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "List",
          lc,
          type_match,
          do_if_type_match,
          do_if_type_match);
    }

    bool seen_first_elem = false;

    // A list comprehension introduces its own scope
    Node* n =
        graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
    auto* comprehension_block = n->addBlock();
    pushFrame(comprehension_block);
    WithInsertPoint guard(comprehension_block);
    auto emit_body = [&]() {
      Value* out = emitExpr(lc.elt());

      // If we didn't have a type annotation, the type of the list would
      // be set to `Tensor`. We don't want to unify this default type
      // with the actual elements in the list, so let the type begin as
      // the first element in the list
      if (!seen_first_elem) {
        list_value->setType(ListType::create(out->type()));
        seen_first_elem = true;
      }

      const auto elem_type_hint =
          refined_type_hint && refined_type_hint->kind() == ListType::Kind
          ? refined_type_hint->cast<ListType>()->getElementType()
          : nullptr;

      std::optional<TypePtr> unified_elem_type = unifyTypes(
          list_value->type()->expect<ListType>()->getElementType(),
          out->type(),
          /*default_to_union=*/true,
          elem_type_hint);

      // Case: The list comprehension generated heterogeneous values,
      // and we don't have a type hint to suggest that this is what the
      // user expected
      if (!type_hint && (*unified_elem_type)->isUnionType()) {
        TORCH_WARN(
            "List consists of heterogeneous types, which means",
            " that it has been typed as containing ",
            (*unified_elem_type)->repr_str(),
            ". To use any of the "
            "values in this List, it will be necessary to add an "
            "`assert isinstance` statement before first use to trigger "
            "type refinement. The first non-matching element was typed",
            " as ",
            out->type()->repr_str(),
            ", while the elements "
            " before it were ",
            list_value->type()
                ->expect<ListType>()
                ->getElementType()
                ->repr_str(),
            "\n",
            lc.range().str());
      }

      // Case: We had an annotation that we were able to narrow down to
      // a single ListType, but the most recently generated element in
      // the list comprehension doesn't match that annotation
      if (all_candidates.empty() && refined_type_hint &&
          !(*unified_elem_type)
               ->isSubtypeOf(*refined_type_hint->expectRef<ListType>()
                                  .getElementType())) {
        throw(
            ErrorReport(lc)
            << "List type annotation `" << refined_type_hint->repr_str()
            << "` did not match the types of the given list elements,"
            << " which were unified to " << (*unified_elem_type)->repr_str());
      }

      if (!all_candidates.empty()) {
        // If we had a Union type annotation that could hold more than
        // one different type of `List`
        refineAndSetListTypeHintFromCandidatesVector(
            all_candidates,
            type_hint,
            &refined_type_hint,
            *unified_elem_type,
            lc);
      } else if (!refined_type_hint) {
        refined_type_hint = ListType::create(*unified_elem_type);
      }

      list_value->setType(refined_type_hint);
      out->setType(refined_type_hint->expect<ListType>()->getElementType());

      // append the element via the builtin list append
      NamedValue self = NamedValue(loc, "self", list_value);
      NamedValue input = NamedValue(loc, "", out);
      emitBuiltinCall(loc, *graph, aten::append, {input}, {}, self);
    };
    emitFor(targets_list, itrs, loc, emit_body);
    popFrame();
    return list_value;
  }
1648
emitDictComprehensiontorch::jit::to_ir1649 Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) {
1650 const auto loc = dc.range();
1651 const auto targets_list = List<Expr>::create(dc.range(), {dc.target()});
1652 const auto itrs = List<Expr>::create(dc.range(), {dc.iter()});
1653
1654 Value* dict_value =
1655 graph->insertNode(graph->create(prim::DictConstruct, 1))->output();
1656
1657 // Set the default type to be Dict[str, Tensor]
1658 dict_value->setType(DictType::create(StringType::get(), TensorType::get()));
1659
1660 TypePtr refined_type_hint = type_hint;
1661 TypePtr annotated_union_type =
1662 type_hint && type_hint->isUnionType() ? type_hint : nullptr;
1663
1664 std::vector<TypePtr> all_candidates = {};
1665
1666 if (refined_type_hint) {
1667 auto type_match = [&](const TypePtr& t) {
1668 return t->kind() == DictType::Kind;
1669 };
1670
1671 auto do_if_match = [&]() { dict_value->setType(refined_type_hint); };
1672
1673 refineAndSetUnionTypeHintOrPopulateCandidatesVector(
1674 type_hint,
1675 &refined_type_hint,
1676 &all_candidates,
1677 "Dict",
1678 dc,
1679 type_match,
1680 do_if_match,
1681 do_if_match);
1682 }
1683
1684 TypePtr first_generated_key_type = nullptr;
1685 TypePtr first_generated_value_type = nullptr;
1686
1687 // A dict comprehension introduces its own scope. No variable assigned
1688 // may leak into the rest of the graph
1689 Node* n =
1690 graph->insertNode(create(prim::ComprehensionScope, dc.range(), 0));
1691 auto* comprehension_block = n->addBlock();
1692 pushFrame(comprehension_block);
1693 WithInsertPoint guard(comprehension_block);
1694 auto emit_body = [&]() {
1695 auto k = emitExpr(dc.key());
1696 auto v = emitExpr(dc.value());
1697
1698 // If we didn't have a type annotation, the type of the dict would
1699 // be set to `(str, Tensor)`. We don't want to unify this default
1700 // type with the actual elements in the dict, so let the type
1701 // begin as the first element in the dict
1702 if (k->type()->kind() == UnionType::Kind) {
1703 throw(
1704 ErrorReport(dc)
1705 << "Dicts may only contain homogeneous keys, but the type of "
1706 << "the first generated key was " << k->type()->repr_str());
1707 } else if (
1708 first_generated_key_type && first_generated_key_type != k->type()) {
1709 // Values can be heterogenous, so we only need to check that the
1710 // key types are all the same
1711 throw(
1712 ErrorReport(dc)
1713 << "Dicts may only contain homogeneous keys. Expected "
1714 << "dict comprehension to generate type "
1715 << first_generated_key_type->repr_str() << ", but got "
1716 << k->type()->repr_str());
1717 } else {
1718 dict_value->setType(DictType::create(k->type(), v->type()));
1719 first_generated_key_type = k->type();
1720 first_generated_value_type = v->type();
1721 }
1722
1723 // If we had any annotation OTHER THAN a Union that can hold more
1724 // than one type of Dict
1725 if (refined_type_hint && all_candidates.empty()) {
1726 DictTypePtr dict_type_hint = refined_type_hint->expect<DictType>();
1727
1728 std::stringstream ss;
1729 std::stringstream err;
1730
1731 bool is_key_subtype =
1732 k->type()->isSubtypeOfExt(*dict_type_hint->getKeyType(), &ss);
1733
1734 if (!is_key_subtype) {
1735 err << "Dict type annotation `" << dict_type_hint->repr_str()
1736 << "` did not match the "
1737 << "type of an actual key type `" << k->type()->repr_str()
1738 << "`\n"
1739 << ss.str();
1740 }
1741
1742 ss.str(std::string());
1743 bool is_value_subtype =
1744 v->type()->isSubtypeOfExt(*dict_type_hint->getValueType(), &ss);
1745
1746 if (!is_value_subtype) {
1747 err << "Dict type annotation `" << dict_type_hint->repr_str()
1748 << "` did not match the "
1749 << "type of an actual value type `" << v->type()->repr_str()
1750 << "`\n"
1751 << ss.str();
1752 }
1753
1754 if (!is_key_subtype || !is_value_subtype) {
1755 throw(ErrorReport(dc) << err.str());
1756 }
1757 }
1758
1759 const TypePtr value_type_hint =
1760 refined_type_hint && refined_type_hint->kind() == DictType::Kind
1761 ? refined_type_hint->expect<DictType>()->getValueType()
1762 : nullptr;
1763
1764 std::optional<TypePtr> unified_value_type = unifyTypes(
1765 first_generated_value_type,
1766 v->type(),
1767 /*default_to_union=*/true,
1768 value_type_hint);
1769
1770 if (!type_hint && (*unified_value_type)->isUnionType()) {
1771 TORCH_WARN(
1772 "Dict values consist of heterogeneous types, which means",
1773 " that they have been typed as being ",
1774 (*unified_value_type)->repr_str(),
1775 ". To use any of the "
1776 "values in this dict, it will be necessary to add an "
1777 "`assert isinstance` statement before first use to trigger "
1778 "type refinement. The first non-matching element was typed",
1779 " as ",
1780 v->type()->repr_str(),
1781 ", while the elements "
1782 " before it were ",
1783 first_generated_value_type->repr_str(),
1784 "\n",
1785 dc.range().str());
1786 }
1787
1788 if (type_hint) {
1789 if (type_hint->kind() == DictType::Kind) {
1790 dict_value->setType(type_hint);
1791 k->setType(type_hint->expect<DictType>()->getKeyType());
1792 v->setType(type_hint->expect<DictType>()->getValueType());
1793 } else {
1794 if (!all_candidates.empty()) {
1795 refineAndSetDictTypeHintFromCandidatesVector(
1796 all_candidates,
1797 type_hint,
1798 &refined_type_hint,
1799 k->type(),
1800 *unified_value_type,
1801 dc);
1802 }
1803 dict_value->setType(refined_type_hint);
1804 k->setType(refined_type_hint->expect<DictType>()->getKeyType());
1805 v->setType(refined_type_hint->expect<DictType>()->getValueType());
1806 }
1807 } else {
1808 dict_value->setType(DictType::create(k->type(), *unified_value_type));
1809 }
1810
1811 NamedValue self = NamedValue(loc, "self", dict_value);
1812 NamedValue input_k = NamedValue(loc, "", k);
1813 NamedValue input_v = NamedValue(loc, "", v);
1814 emitBuiltinCall(
1815 loc, *graph, aten::_set_item, {self, input_k, input_v}, {});
1816 };
1817 emitFor(targets_list, itrs, loc, emit_body);
1818 popFrame();
1819
1820 if (annotated_union_type) {
1821 Node* n =
1822 graph->insertNode(graph->create(prim::unchecked_cast, {dict_value}));
1823 n->output()->setType(std::move(annotated_union_type));
1824 dict_value = n->output();
1825 }
1826
1827 return dict_value;
1828 }
1829
1830 // Insert subtyping refinements
insertRefinementstorch::jit::to_ir1831 void insertRefinements(const SourceRange& loc, const RefinementSet& ref) {
1832 for (const Refinement& r : ref.activeRefinements()) {
1833 Value* v = environment_stack->getVar(r.identifier(), loc);
1834 Value* new_v = graph->insertUncheckedCast(v, r.type());
1835 environment_stack->setVar(loc, r.identifier(), new_v);
1836 }
1837 }
1838
  // Emits a short-circuiting `and`/`or` (`is_or` selects which) as a
  // prim::If expression. The second operand is only evaluated in the branch
  // that Python semantics require; the skipped branch yields a constant.
  // Also merges the operands' refinement sets and computes a statically
  // known truth value when both sides allow it.
  CondValue emitShortCircuitLogical(
      const SourceRange& loc,
      const Expr& first_expr,
      const Expr& second_expr,
      bool is_or) {
    CondValue lhs = emitCondExpr(first_expr);
    // if the continue expr in the short circuit is not evaluated,
    // then the const expression is False if the short circuit
    // is an `and` and True if the short circuit is an `or`.
    // `False and expr` -> False, `True or expr` -> True
    //
    // inserting it as a constant makes optimization easier

    // if it's an OR the first expr is emitted in the true branch
    // and the second expr in the false branch, if it's an AND the opposite
    auto get_const_expr = [&] { return graph->insertConstant(is_or, loc); };

    // `rhs` is populated as a side effect of emitting the continuation
    // branch; emitIfExpr below emits both branches eagerly, so `rhs` is
    // always engaged by the time it is dereferenced.
    std::optional<CondValue> rhs;
    auto get_continue_expr = [&] {
      rhs = emitCondExpr(second_expr);
      return rhs->value();
    };

    // if this is an OR, eval second expression if first expr is False
    // If this is an AND, eval second expression if first expr is True
    Value* new_result = nullptr;
    std::optional<RefinementSet> refinements;
    std::optional<bool> static_if;
    if (is_or) {
      new_result = emitIfExpr(loc, lhs, get_const_expr, get_continue_expr);
      refinements = lhs.refinements().Or(rhs->refinements());
      // statically true if either side is statically true
      if ((lhs.staticIf() && *lhs.staticIf()) ||
          (rhs->staticIf() && *rhs->staticIf())) {
        static_if = true;
      } else if (lhs.staticIf() && rhs->staticIf()) {
        static_if = *lhs.staticIf() || *rhs->staticIf();
      }
    } else {
      new_result = emitIfExpr(loc, lhs, get_continue_expr, get_const_expr);
      refinements = lhs.refinements().And(rhs->refinements());
      // statically false if either side is statically false
      if (((lhs.staticIf() && !*lhs.staticIf()) ||
           (rhs->staticIf() && !*rhs->staticIf()))) {
        static_if = false;
      } else if (lhs.staticIf() && rhs->staticIf()) {
        static_if = *lhs.staticIf() && *rhs->staticIf();
      }
    }
    return CondValue(new_result, std::move(*refinements), static_if);
  }
1888
emitIfExprtorch::jit::to_ir1889 Value* emitIfExpr(
1890 const SourceRange& range,
1891 const CondValue& cond_value,
1892 const std::function<Value*()>& true_expr,
1893 const std::function<Value*()>& false_expr) {
1894 Node* n = graph->insertNode(create(prim::If, range, 0));
1895 n->addInput(cond_value.value());
1896 auto* true_block = n->addBlock();
1897 auto* false_block = n->addBlock();
1898
1899 auto emit_if_expr = [this, &range](
1900 Block* b,
1901 const RefinementSet& refinements,
1902 const std::function<Value*()>& expr_value) {
1903 pushFrame(b);
1904 WithInsertPoint guard(b);
1905 insertRefinements(range, refinements);
1906 Value* out_val = expr_value();
1907 b->registerOutput(out_val);
1908 popFrame();
1909 };
1910
1911 emit_if_expr(true_block, cond_value.refinements(), true_expr);
1912 emit_if_expr(false_block, cond_value.refinements().Not(), false_expr);
1913
1914 auto true_type = true_block->outputs().at(0)->type();
1915 auto false_type = false_block->outputs().at(0)->type();
1916 auto unified = unifyTypes(true_type, false_type);
1917 if (!unified) {
1918 throw(
1919 ErrorReport(range)
1920 << "if-expression's true branch has type " << true_type->repr_str()
1921 << " but false branch has type " << false_type->repr_str());
1922 }
1923
1924 // Add op outputs
1925 auto expr_value = n->addOutput()->setType(*unified); // Resulting value
1926
1927 return expr_value;
1928 }
emitToBooltorch::jit::to_ir1929 Value* emitToBool(const SourceRange& loc, Value* v) {
1930 Value* out = nullptr;
1931 try {
1932 auto bool_cast = environment_stack->getSugaredVar("bool", loc);
1933 out = asSimple(bool_cast->call(loc, method, {v}, {}, 0));
1934 } catch (...) {
1935 throw(
1936 ErrorReport(loc) << "Could not cast value of type "
1937 << v->type()->repr_str() << " to bool");
1938 }
1939 if (!out) {
1940 throw(
1941 ErrorReport(loc) << "Could not cast value of type "
1942 << v->type()->repr_str() << " to bool");
1943 }
1944 // cast value not response for checking output type
1945 if (!out->type()->isSubtypeOf(*BoolType::get())) {
1946 throw(
1947 ErrorReport(loc)
1948 << "expected a bool expression for condition but found "
1949 << out->type()->repr_str());
1950 }
1951 return out;
1952 }
1953
emitIfElseBlockstorch::jit::to_ir1954 void emitIfElseBlocks(
1955 const SourceRange& loc,
1956 const CondValue& cond_value,
1957 const List<Stmt>& trueBranch,
1958 const List<Stmt>& falseBranch) {
1959 // this is a static if statement: that is, it contains a subset
1960 // of operators where we are willing to specialize the if statement
1961 // to be only the true or false branch when the condition is statically
1962 // known. This is used to meta-program modules, for instance, when a
1963 // submodule is absent, an is None check can be used to ensure the
1964 // accesses to the None check, which would error, are not compiled.
1965 if (cond_value.staticIf()) {
1966 if (*cond_value.staticIf()) {
1967 insertRefinements(loc, cond_value.refinements());
1968 emitStatements(trueBranch);
1969 } else {
1970 insertRefinements(loc, cond_value.refinements().Not());
1971 emitStatements(falseBranch);
1972 }
1973 return;
1974 }
1975
1976 Node* n = graph->insertNode(create(prim::If, loc, 0));
1977 n->addInput(cond_value.value());
1978 auto* true_block = n->addBlock();
1979 auto* false_block = n->addBlock();
1980
1981 // Emit both blocks once to get the union of all mutated values
1982 auto save_true =
1983 emitSingleIfBranch(true_block, trueBranch, cond_value.refinements());
1984 auto save_false = emitSingleIfBranch(
1985 false_block, falseBranch, cond_value.refinements().Not());
1986
1987 bool true_exits = exit_blocks.count(true_block);
1988 bool false_exits = exit_blocks.count(false_block);
1989 if (true_exits && false_exits) {
1990 exit_blocks.insert(n->owningBlock());
1991 }
1992
1993 // In python, every variable assigned in an if statement escapes
1994 // the scope of the if statement (all variables are scoped to the function).
1995 // Script is a subset of python: we consider variables to be in scope
1996 // as long as there is a definition of the variable along all paths
1997 // through the if statement
1998 // ----
1999 // if ...:
2000 // a =
2001 // else:
2002 // ...
2003 // ... = a # error, a is not defined along all paths
2004 // ----
2005 // if ...:
2006 // a =
2007 // else:
2008 // a =
2009 // ... = a # OK, a is defined along all paths
2010 // ----
2011 // a = ...
2012 // if ...:
2013 // a =
2014 // ... = a # OK, a is defined along all paths
2015 // if ...:
2016 // a =
2017 // else:
2018 // return
2019 // ... = a # OK, a is always defined
2020
2021 // ordered set, because we want deterministic graph output
2022 std::set<std::string> mutated_variables;
2023
2024 // When we access either the true or false environment,
2025 // we need to set the insertion point so the prim::Load is inserted
2026 // into the right block.
2027 // if var is only defined in one branch save error in case it's used later
2028 for (auto& v : save_true->definedVariables()) {
2029 {
2030 WithInsertPoint insert(false_block);
2031 if (save_false->findInAnyFrame(v) || false_exits) {
2032 mutated_variables.insert(v);
2033 } else {
2034 if (reportSourceLocation(loc.source()->size())) {
2035 ErrorReport error(loc);
2036 environment_stack->setVariableTypeError(v, [=]() -> std::string {
2037 error << v << " is not defined in the false branch";
2038 return error.what();
2039 });
2040 } else {
2041 environment_stack->setVariableTypeError(v, [=]() -> std::string {
2042 std::stringstream ss;
2043 ss << v << " is not defined in the false branch. "
2044 << "The source info is eliminated due to the source file is too large. "
2045 << "To get it back, please set PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION=1 "
2046 << "as env var";
2047 return ss.str();
2048 });
2049 }
2050 }
2051 }
2052 }
2053 for (auto& v : save_false->definedVariables()) {
2054 {
2055 WithInsertPoint insert(true_block);
2056 if (save_true->findInAnyFrame(v) || true_exits) {
2057 mutated_variables.insert(v);
2058 } else {
2059 if (reportSourceLocation(loc.source()->size())) {
2060 ErrorReport error(loc);
2061 environment_stack->setVariableTypeError(v, [=]() -> std::string {
2062 error << v << " is not defined in the true branch";
2063 return error.what();
2064 });
2065 } else {
2066 environment_stack->setVariableTypeError(v, [=]() -> std::string {
2067 std::stringstream ss;
2068 ss << v << " is not defined in the false branch. "
2069 << "The source info is eliminated due to the source file is too large. "
2070 << "To get it back, please set PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION=1 "
2071 << "as env var";
2072 return ss.str();
2073 });
2074 }
2075 }
2076 }
2077 }
2078
2079 // Register outputs in each block
2080 for (const auto& x : mutated_variables) {
2081 Value* tv = nullptr;
2082 Value* fv = nullptr;
2083
2084 {
2085 WithInsertPoint insert(true_block);
2086 if (!true_exits) {
2087 tv = save_true->getVar(x, loc);
2088 }
2089 }
2090 {
2091 WithInsertPoint insert(false_block);
2092 if (!false_exits) {
2093 fv = save_false->getVar(x, loc);
2094 }
2095 }
2096
2097 // if both branches exit don't emit any variables
2098 // if one branch exits then we allow the all variables in the other branch
2099 // to escape scope since they are well-defined
2100 if (true_exits && false_exits) {
2101 continue;
2102 } else if (true_exits) {
2103 tv = graph->createUninitialized(fv->type())
2104 ->insertBefore(true_block->return_node())
2105 ->output();
2106 graph->createStore(x, tv)->insertBefore(true_block->return_node());
2107 } else if (false_exits) {
2108 fv = graph->createUninitialized(tv->type())
2109 ->insertBefore(false_block->return_node())
2110 ->output();
2111 graph->createStore(x, fv)->insertBefore(false_block->return_node());
2112 }
2113
2114 SugaredValuePtr maybe_sugared_x = environment_stack->findInAnyFrame(x);
2115 TypePtr full_type = nullptr;
2116 if (maybe_sugared_x) {
2117 Value* maybe_simple = asSimple(maybe_sugared_x);
2118 if (maybe_simple) {
2119 full_type = maybe_simple->type();
2120 }
2121 }
2122
2123 // Try to unify the types. If we found a type annotation earlier
2124 // in the environment, and if that type annotation is some form
2125 // of union, then we need to tell `unifyTypes` not to throw an
2126 // error if the branched return types we found are heterogenous
2127 bool default_to_union = full_type &&
2128 (full_type->kind() == UnionType::Kind ||
2129 full_type->kind() == OptionalType::Kind ||
2130 full_type->kind() == NumberType::Kind);
2131 auto unified = unifyTypes(
2132 tv->type(), fv->type(), /*default_to_union=*/default_to_union);
2133
2134 // We allow variables to be set to different types in each branch
2135 // as long as that variable is not already in scope or if that
2136 // variable does not get used later. Here, we save the error so
2137 // that the error message will be more informative in the case
2138 // that is used later. When `a` is accessed in `(a + 1)`, the
2139 // error will get printed:
2140 // if cond:
2141 // a = 1
2142 // else:
2143 // a = tensor
2144 // b = a + 1
2145 //
2146 if (!unified) {
2147 ErrorReport error(loc);
2148 error << "Type mismatch: " << x << " is set to type "
2149 << tv->type()->repr_str() << " in the true branch"
2150 << " and type " << fv->type()->repr_str()
2151 << " in the false branch";
2152 if (save_true->findInParentFrame(x) ||
2153 save_false->findInParentFrame(x)) {
2154 throw ErrorReport(error);
2155 } else {
2156 environment_stack->setVariableTypeError(
2157 x, [=]() -> std::string { return error.what(); });
2158 continue;
2159 }
2160 }
2161 environment_stack->setType(x, *unified);
2162 }
2163 }
2164
emitHasAttrtorch::jit::to_ir2165 CondValue emitHasAttr(const Expr& objExpr, const Expr& attrExpr) {
2166 auto obj = emitSugaredExpr(objExpr, 1);
2167 if (attrExpr.kind() != TK_STRINGLITERAL) {
2168 throw(
2169 ErrorReport(attrExpr)
2170 << "hasattr's second argument must be a string literal");
2171 }
2172 const std::string& name = StringLiteral(attrExpr).text();
2173 const bool hasAttr = obj->hasAttr(objExpr.range(), method, name);
2174 return CondValue(*graph, objExpr.range(), hasAttr, {});
2175 }
2176
  // Emits `isinstance(obj, classinfo)`, where `classinfo` may be a single
  // type expression or a (possibly nested) tuple of them. The result is
  // folded to a compile-time constant when the relationship between the
  // lhs type and the rhs types is statically known; otherwise a runtime
  // prim::isinstance node is emitted. True/false-branch type refinements
  // for `obj` (when it is a plain variable) are attached to the returned
  // CondValue.
  CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) {
    Value* lhs_val = emitExpr(obj);
    std::vector<TypePtr> lhs_types;
    std::vector<TypePtr> rhs_types;

    // Recursively flatten tuple literals on the rhs into a list of types.
    std::function<void(const Expr&)> gather_rhs = [&](const Expr& expr) {
      if (expr.kind() == TK_TUPLE_LITERAL) {
        for (Expr e : TupleLiteral(expr).inputs()) {
          gather_rhs(e);
        }
        return;
      }
      TypePtr type = typeParser_.parseTypeFromExpr(expr);
      rhs_types.emplace_back(type);
    };

    lhs_types.push_back(lhs_val->type());
    gather_rhs(classinfo);

    // Flatten any Union types into their alternatives so the pairwise
    // comparisons below see concrete member types.
    standardizeVectorForUnion(&lhs_types);
    standardizeVectorForUnion(&rhs_types);

    RefinementSet refinement;

    TypePtr unified_true = nullptr;
    TypePtr unified_false = nullptr;

    std::vector<TypePtr> isinstance_types;
    std::vector<TypePtr> not_isinstance_types;

    std::vector<Refinement> true_refinements;
    std::vector<Refinement> false_refinements;

    // Stays true only while every lhs type is covered by some rhs type,
    // in which case the whole check is statically true.
    bool all_lhs_subtype_some_rhs = true;

    // We can discard any rhs types that we know statically would be
    // impossible. For example, if we had:
    //
    //    def fn(x: Optional[str]):
    //        if isinstance(x, (List[str], str, int)):
    //            ...
    //
    // then `x` would be `str` in the true branch and `None` in the
    // false branch, not `(List[str], str, int)` in the true branch
    // and `None` in the false branch
    for (const TypePtr& lhs_type : lhs_types) {
      if (lhs_type == AnyType::get()) {
        isinstance_types.insert(
            isinstance_types.end(), rhs_types.begin(), rhs_types.end());
        not_isinstance_types.emplace_back(AnyType::get());
        // Edge case: we can still say that all lhs types subtype some
        // rhs type if `lhs` is `Any` and `rhs` is `Any`
        if (isinstance_types.size() != 1 ||
            isinstance_types[0] != AnyType::get()) {
          all_lhs_subtype_some_rhs = false;
        }
        break;
      }

      // Returns whichever of t1/t2 subtypes the other (the "narrower"
      // type), or nullptr when they are unrelated.
      auto get_smaller_type = [&](const TypePtr& t1,
                                  const TypePtr& t2) -> TypePtr {
        if (t1->isSubtypeOf(*t2)) {
          return t1;
        } else if (t2->isSubtypeOf(*t1)) {
          return t2;
        } else {
          return nullptr;
        }
      };

      TypePtr found_refinement = nullptr;
      for (const TypePtr& rhs_type : rhs_types) {
        TypePtr maybe_smaller_type = get_smaller_type(lhs_type, rhs_type);
        if (!maybe_smaller_type) {
          continue;
        } else if (*maybe_smaller_type == *lhs_type) {
          // Cover the case that we have something like
          // lhs = `List[str]` and rhs = `list`
          found_refinement = lhs_type;
        } else if (*maybe_smaller_type == *rhs_type) {
          // We want the narrowest possible type
          found_refinement = found_refinement
              ? *(unifyTypes(found_refinement, rhs_type))
              : rhs_type;
        }
      }

      if (found_refinement) {
        if (*found_refinement == *lhs_type) {
          // no-op: the flag only ever transitions to false (else below);
          // this records that the current lhs type is fully covered
          all_lhs_subtype_some_rhs &= true;
        }
        isinstance_types.push_back(found_refinement);
      } else {
        // If the lhs couldn't be a subtype of the rhs (or couldn't
        // be "refined" to itself, as in the `List[str]` and `list`
        // case above), then we add `lhs_type` to the false branch
        // refinements. This is because the type can still be itself
        // if the `isinstance` check is false
        not_isinstance_types.push_back(lhs_type);
        all_lhs_subtype_some_rhs = false;
      }
    }

    // For use with `unifyTypeList`
    std::stringstream nowhere;

    // Get a single type for the true and false branches
    if (!isinstance_types.empty()) {
      unified_true =
          *unifyTypeList(isinstance_types, nowhere, /*default_to_union=*/true);
    }
    if (obj.kind() == TK_VAR && unified_true) {
      std::string ident = Var(obj).name().name();
      true_refinements = {Refinement(ident, unified_true)};
    }

    // Get a single type for the true and false branches
    if (!not_isinstance_types.empty()) {
      unified_false = *unifyTypeList(
          not_isinstance_types, nowhere, /*default_to_union=*/true);
    }
    if (obj.kind() == TK_VAR && unified_false) {
      std::string ident = Var(obj).name().name();
      false_refinements = {Refinement(ident, unified_false)};
    }

    refinement = RefinementSet(true_refinements, false_refinements);

    // No lhs type can match any rhs type => statically false.
    bool is_statically_false = isinstance_types.empty();

    // If the statement is statically true
    if (all_lhs_subtype_some_rhs) {
      return CondValue(*graph, obj.range(), true, std::move(refinement));
    }

    if (is_statically_false) {
      return CondValue(*graph, obj.range(), false, std::move(refinement));
    }

    // check maybe true/false at runtime, need an actual op
    Value* result =
        graph->insertNode(graph->createIsInstance(lhs_val, rhs_types))
            ->output();
    return CondValue(result, std::move(refinement), std::nullopt);
  }
2322
emitIftorch::jit::to_ir2323 void emitIf(const If& stmt) {
2324 Expr cond = stmt.cond();
2325 CondValue cond_value = emitCondExpr(cond);
2326 emitIfElseBlocks(
2327 stmt.range(), cond_value, stmt.trueBranch(), stmt.falseBranch());
2328 }
2329
  // *********************** Loop Operators ************************************
  // Emits a loop operator with the form:
  // Loop(max_trip_count)
  // block0(loop_counter) {
  //   <body>
  // }
  // block1 {
  //   <loop condition>
  //   -> (condition)
  // }
  // For loops will have an empty loop condition block with condition set to
  // true. In the convert-to-SSA pass, the loop condition will be correctly
  // inlined, and inputs and outputs will be added so that the loop conforms
  // to the semantics specified at
  // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
  //
  // `iter_val` - sugared iterable for `for` loops (nullptr for `while`);
  //              its len() bounds the trip count.
  // `targets`  - loop-variable expressions assigned each iteration (for-loops).
  // `cond`     - loop condition expression (while-loops).
  void emitLoopCommon(
      const SourceRange& range,
      const std::function<void()>& emit_body,
      const SugaredValuePtr& iter_val,
      std::optional<List<Expr>> targets,
      std::optional<Expr> cond) {
    Value* max_trip_count_val = nullptr;
    if (iter_val != nullptr) {
      max_trip_count_val = iter_val->len(range, method);
    } else {
      // No iterable (a `while` loop): bound the trip count by int64 max.
      max_trip_count_val = materializeConstant(
          std::numeric_limits<int64_t>::max(),
          *graph,
          range,
          integral_constants);
    }

    Node* n = graph->insertNode(create(prim::Loop, range, 0));
    auto* body_block = n->addBlock();
    {
      // Condition block: the `while` condition, or constant `true` for
      // `for` loops (inlined later by convert-to-SSA; see comment above).
      Block* condition_block = n->addBlock();
      pushFrame(condition_block);
      Value* out = nullptr;
      if (cond) {
        WithInsertPoint insert(condition_block);
        out = emitToBool(cond.value().range(), emitExpr(cond.value()));
      } else {
        WithInsertPoint insert(n);
        out = graph->insertConstant(true, range);
      }
      condition_block->registerOutput(out);
      popFrame();
    }
    n->addInput(max_trip_count_val);

    WithLoopStatus loop_guard(&loop_status_, LoopStatus::IN_LOOP);
    Value* trip_count =
        body_block->addInput()->setType(IntType::get()); // Iteration num
    {
      pushFrame(body_block);
      WithInsertPoint guard(body_block);

      // if the FOR iters and targets are present, emit FOR target assignments
      if (iter_val != nullptr && targets) {
        Value* cur_elem = iter_val->getitem(range, method, trip_count)
                              ->asValue(range, method);
        SugaredValuePtr sv = std::make_shared<SimpleValue>(cur_elem);
        List<Expr> target_exprs = targets.value();
        validateAssignLhsExpr(target_exprs, range);

        // if target exprs are more than 1, it means iteration unpacking on LHS
        // we create Tuple literal to wrap those target exprs for assignments
        if (target_exprs.size() > 1) {
          Expr tl = TupleLiteral::create(range, target_exprs);
          target_exprs = List<Expr>::create(range, {tl});
        }
        emitExprsAssign(target_exprs, {sv}, range, /*n_binders=*/1);
      }
      emit_body();
      popFrame();
    }
  }
2407
emitUnrolledLooptorch::jit::to_ir2408 void emitUnrolledLoop(
2409 const SourceRange& loc,
2410 const std::function<void()>& emit_body,
2411 const SugaredValuePtr& iterable,
2412 const List<Expr>& targets) {
2413 auto static_len = iterable->staticLen();
2414 TORCH_INTERNAL_ASSERT(
2415 static_len, "Unrolled loop iter should have static length");
2416 int64_t len = *static_len;
2417 WithLoopStatus loop_guard(&loop_status_, LoopStatus::IN_UNROLLED_LOOP);
2418 // In order to support ModuleLists which return different types,
2419 // as with an nn.Sequential which has a module that returns a Dict and then
2420 // a module which returns a Tensor,
2421 // we do not push a new environment frame because if we did all intermediary
2422 // values would have to subtype the input type.
2423 for (const auto i : c10::irange(len)) {
2424 auto index =
2425 materializeConstant(i, *method.graph(), loc, integral_constants);
2426 auto sugared_value = iterable->getitem(loc, method, index);
2427 emitExprsAssign(
2428 targets, {sugared_value}, targets.range(), /*n_binders=*/1);
2429 emit_body();
2430 }
2431 }
2432
emitFortorch::jit::to_ir2433 void emitFor(
2434 const List<Expr>& targets,
2435 const List<Expr>& itrs,
2436 const SourceRange& loc,
2437 const std::function<void()>& emit_body) {
2438 if (itrs.size() != 1) {
2439 throw(ErrorReport(loc) << "List of iterables is not supported currently");
2440 }
2441
2442 // Emit loop information for builtinFunction values like range(), zip(),
2443 // enumerate() or SimpleValue like List, Tensor, Dict, etc.
2444 SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1);
2445 SugaredValuePtr iterable = sv->iter(loc, method);
2446
2447 // We unroll the loop for iterables that contain ModuleLists so that we can
2448 // compile Heterogenous module lists.
2449 if (!iterable->shouldEmitUnrolled()) {
2450 emitLoopCommon(loc, emit_body, iterable, targets, {});
2451 } else {
2452 emitUnrolledLoop(loc, emit_body, iterable, targets);
2453 }
2454 }
2455
emitFortorch::jit::to_ir2456 void emitFor(const For& stmt) {
2457 auto emit_body = [&]() { emitStatements(stmt.body()); };
2458 emitFor(stmt.targets(), stmt.itrs(), stmt.range(), emit_body);
2459 }
2460
emitWhiletorch::jit::to_ir2461 void emitWhile(const While& stmt) {
2462 auto cond = stmt.cond();
2463 auto emit_body = [&]() { emitStatements(stmt.body()); };
2464 emitLoopCommon(stmt.range(), emit_body, nullptr, {}, cond);
2465 }
2466
  // Emits a `with` statement: for each with-item, evaluates the expression,
  // validates its __enter__/__exit__ schemas, and emits a prim::Enter node;
  // after the body, prim::Exit nodes are emitted in reverse order of entry.
  void emitWith(const With& stmt) {
    auto targets = stmt.targets();
    // Keep a stack of entered objects so they can be exited
    // in the right order.
    std::stack<Value*> entered;

    for (const auto& target : targets) {
      Expr e = target.target();

      auto* rhs = emitExpr(e);
      auto* n = graph->insertNode(graph->create(prim::Enter, {rhs}));
      entered.push(rhs);

      // Only class instances can act as context managers here.
      if (rhs->type()->kind() != TypeKind::ClassType) {
        throw(
            ErrorReport(e.range())
            << "With item expression must return an object");
      }

      auto rhsClass = rhs->type()->expect<ClassType>();
      auto* enterMethod = rhsClass->findMethod("__enter__");
      auto* exitMethod = rhsClass->findMethod("__exit__");

      if (!enterMethod || !exitMethod) {
        throw(
            ErrorReport(e.range())
            << "Object returned by with item expression does not define __enter__ and __exit__ methods");
      }

      // Check the schema of __enter__.
      auto& enterSchema = enterMethod->getSchema();
      if (enterSchema.arguments().size() != 1) {
        throw(
            ErrorReport(e.range())
            << "__enter__ must have only one argument and one return value");
      }

      // Check the schema of __exit__.
      auto& exitSchema = exitMethod->getSchema();
      if (exitSchema.arguments().size() != 4) {
        throw(ErrorReport(e.range()) << "__exit__ must have four arguments");
      } else {
        // Arguments 1..3 (the exception type/value/traceback slots) must be
        // typed Any, since TorchScript never passes real exception info.
        for (unsigned i = 1; i < 4; ++i) {
          if (exitSchema.arguments().at(i).type() != AnyType::get()) {
            throw(
                ErrorReport(e.range())
                << "argument " << i
                << " of __exit__ must have Any type; TorchScript does not currently support passing exception type, value, or traceback to the __exit__ function.");
          }
        }
      }

      // Set the output of the enter node to be the return type of __enter__.
      n->output(0)->setType(enterSchema.returns().at(0).type());

      // Set i = e.__enter__() so that references to i in the body of the with
      // will resolve correctly.
      if (target.var().present()) {
        Var i = target.var().get();
        environment_stack->setVar(i.range(), i.name().name(), n->output(0));
      }
    }

    emitStatements(stmt.body());

    // Insert all the corresponding prim::Exit nodes.
    while (!entered.empty()) {
      auto* input = entered.top();
      entered.pop();
      auto* n = graph->create(prim::Exit);
      graph->insertNode(n);
      n->addInput(input);
    }
  }
2541
2542 // Currently we do not support assigning exceptions to variables,
2543 // a = Exception("hi")
2544 // raise a
2545 //
2546 // We ignore the expression following raise
emitRaisetorch::jit::to_ir2547 void emitRaise(const Raise& raise) {
2548 auto sv = emitSugaredExpr(raise.expr(), 1);
2549 Value* error_message = nullptr;
2550 Value* qualified_class_name = nullptr;
2551
2552 if (auto exception_instance =
2553 std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
2554 // The typical case, an instance of the exception class was thrown:
2555 // raise RuntimeError("error")
2556 error_message = exception_instance->getValue();
2557 qualified_class_name = exception_instance->getQualifiedClassName();
2558 } else if (
2559 auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
2560 // A bare exception was thrown so add an empty message. e.g.
2561 // raise RuntimeError
2562 error_message = insertConstant(*graph, "", raise.range());
2563 } else {
2564 // The raise was not followed by an exception (i.e. it was something like
2565 // `raise "error"` instead of `raise RuntimeError("error")`)
2566 throw(
2567 ErrorReport(raise.range())
2568 << "exceptions must derive from BaseException");
2569 }
2570
2571 if (!error_message->type()->isSubtypeOf(*StringType::get())) {
2572 error_message = graph->insert(aten::str, {error_message});
2573 }
2574
2575 graph->insert(
2576 prim::RaiseException,
2577 {error_message, qualified_class_name},
2578 {},
2579 raise.range());
2580 exit_blocks.insert(environment_stack->block());
2581 }
2582
  // emit assertions as an if branch so that assertions will reuse the
  // message
emitAsserttorch::jit::to_ir2585 void emitAssert(const Assert& stmt) {
2586 CondValue cond_value = emitCondExpr(stmt.test());
2587 List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
2588 // Create an `AssertionError("the_message")` call
2589 auto message = (stmt.msg().present())
2590 ? stmt.msg().get()
2591 : StringLiteral::create(stmt.range(), "");
2592 auto callee = Var::create(
2593 stmt.range(), Ident::create(stmt.range(), "AssertionError"));
2594 auto apply = Apply::create(
2595 stmt.range(),
2596 callee,
2597 List<Expr>::create(stmt.range(), {message}),
2598 List<Attribute>::create(stmt.range(), {}));
2599
2600 List<Stmt> false_branch =
2601 List<Stmt>::create(stmt.range(), {Raise::create(stmt.range(), apply)});
2602 emitIfElseBlocks(stmt.range(), cond_value, true_branch, false_branch);
2603 }
2604
2605 // Validate that the `lhs` Expr's in an assignment statement are valid. That
2606 // is:
2607 //
2608 // 1) All lhs Expr's are either Var, Tuple or Starred nodes
2609 // 2) There is at most one Starred node in the lhs Expr
2610 // 3) A Starred node can only appear when there is another non-Starred lhs
2611 // Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
2612 // all outputs into a tuple is covered by `abc = func()`.
validateAssignLhsExprtorch::jit::to_ir2613 bool validateAssignLhsExpr(const List<Expr>& lhs, const SourceRange& r) {
2614 size_t num_normal_assign = 0;
2615 size_t num_starred = 0;
2616 for (const auto& assignee : lhs) {
2617 if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT ||
2618 assignee.kind() == TK_TUPLE_LITERAL || assignee.kind() == '.') {
2619 num_normal_assign++;
2620 } else if (assignee.kind() == TK_STARRED) {
2621 num_starred++;
2622 } else {
2623 throw(
2624 ErrorReport(assignee) << "lhs of assignment must be a variable, "
2625 << "subscript, or starred expression");
2626 }
2627 }
2628
2629 if (num_starred > 1) {
2630 throw(
2631 ErrorReport(r)
2632 << "Only one starred expression is allowed on the lhs");
2633 }
2634
2635 if (num_starred > 0 && num_normal_assign == 0) {
2636 throw(
2637 ErrorReport(r) << "A Starred expression may only appear on the "
2638 << "lhs within the presence of another non-starred"
2639 << " expression");
2640 }
2641
2642 return num_starred;
2643 }
2644
2645 // Get the appropriate builtin op for this augmented assignment
2646 // If the RHS is a tensor, return the corresponding ATen in-place op
2647 // If it's a list of scalars, then return the corresponding list augment op
getAugOptorch::jit::to_ir2648 Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) {
2649 bool use_inplace_op = type->isSubtypeOf(*TensorType::get()) ||
2650 type->kind() == TypeKind::ListType;
2651 switch (stmt.aug_op()) {
2652 case '+':
2653 return use_inplace_op ? aten::add_ : aten::add;
2654 case '-':
2655 return use_inplace_op ? aten::sub_ : aten::sub;
2656 case '/':
2657 return use_inplace_op ? aten::div_ : aten::div;
2658 case '*':
2659 return use_inplace_op ? aten::mul_ : aten::mul;
2660 case '%':
2661 return use_inplace_op ? aten::fmod_ : aten::fmod;
2662 case '|':
2663 return use_inplace_op ? aten::bitwise_or : aten::__or__;
2664 case '&':
2665 return use_inplace_op ? aten::bitwise_and : aten::__and__;
2666 case '^':
2667 return use_inplace_op ? aten::bitwise_xor : aten::__xor__;
2668 case TK_LSHIFT:
2669 return use_inplace_op ? aten::__ilshift__ : aten::__lshift__;
2670 case TK_RSHIFT:
2671 return use_inplace_op ? aten::__irshift__ : aten::__rshift__;
2672 case TK_POW:
2673 return aten::pow;
2674 default:
2675 throw(
2676 ErrorReport(stmt)
2677 << "Unknown augmented assignment: " << kindToString(stmt.aug_op()));
2678 }
2679 }
2680
2681 // Get a pair of <in place magic method name, out of place magic method name>
2682 // since the out of place method is called if the in place method is not
2683 // present
getAugMagicMethodtorch::jit::to_ir2684 std::pair<std::string, std::string> getAugMagicMethod(const AugAssign& stmt) {
2685 switch (stmt.aug_op()) {
2686 case '+':
2687 return std::make_pair(std::string("__iadd__"), std::string("__add__"));
2688 case '-':
2689 return std::make_pair(std::string("__isub__"), std::string("__sub__"));
2690 case '/':
2691 return std::make_pair(
2692 std::string("__itruediv__"), std::string("__truediv__"));
2693 case '*':
2694 return std::make_pair(std::string("__imul__"), std::string("__mul__"));
2695 case '%':
2696 return std::make_pair(std::string("__imod__"), std::string("__mod__"));
2697 default:
2698 throw(
2699 ErrorReport(stmt)
2700 << "Unknown augmented assignment: " << kindToString(stmt.aug_op()));
2701 }
2702 }
2703
2704 // Emit nodes for augmented assignments like `+=`
emitAugAssignmenttorch::jit::to_ir2705 void emitAugAssignment(const AugAssign& stmt) {
2706 switch (stmt.lhs().kind()) {
2707 case TK_VAR: {
2708 emitAugAssignmentToVar(stmt);
2709 } break;
2710 case '.': {
2711 emitAugAssignmentToSelectVar(stmt);
2712 } break;
2713 case TK_SUBSCRIPT: {
2714 emitAugAssignmentToSubscript(stmt);
2715 } break;
2716 default:
2717 throw(
2718 ErrorReport(stmt.lhs())
2719 << "unexpected expression on "
2720 << "left-hand side of augmented assignment");
2721 }
2722 }
2723
2724 // This will be called when there is a class param or module buffer
2725 // mutation which make the LHS of the expr be a select expression
2726 //
2727 // Example like:
2728 // class A(Module):
2729 // def __init__():
2730 // self.register_buffer("running_var", torch.zeros(1))
2731 //
2732 // def forward():
2733 // self.num_batches += 1
emitAugAssignmentToSelectVartorch::jit::to_ir2734 void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
2735 const auto lhs = Select(stmt.lhs());
2736 auto lhsSugaredVar = emitSugaredExpr(lhs.value(), 1);
2737 const auto lhsValue =
2738 lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
2739 ->asValue(lhs.range(), method);
2740 auto result = emitAugAssignmentHelper(stmt, lhsValue);
2741 lhsSugaredVar->setAttr(stmt.range(), method, lhs.selector().name(), result);
2742 }
2743
emitAugAssignmentToVartorch::jit::to_ir2744 void emitAugAssignmentToVar(const AugAssign& stmt) {
2745 const auto lhs = Var(stmt.lhs());
2746 auto lhsValue = emitExpr(lhs);
2747 auto result = emitAugAssignmentHelper(stmt, lhsValue);
2748 environment_stack->setVar(lhs.range(), lhs.name().name(), result);
2749 }
2750
emitAugAssignmentHelpertorch::jit::to_ir2751 Value* emitAugAssignmentHelper(const AugAssign& stmt, Value* lhs) {
2752 if (lhs->type()->kind() == TypeKind::ClassType) {
2753 // Call `__iadd__` so updates happen in place on class types
2754 // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
2755 std::string in_place_method_name;
2756 std::string out_of_place_method_name;
2757 std::tie(in_place_method_name, out_of_place_method_name) =
2758 getAugMagicMethod(stmt);
2759 const auto rhs = emitExpr(stmt.rhs());
2760
2761 // Determine whether to use __iadd__ or __add__ (use __add__ only if
2762 // __iadd__ is not present)
2763 auto type = lhs->type()->expect<ClassType>();
2764 std::string magic_method_name;
2765 if (type->findMethod(in_place_method_name)) {
2766 magic_method_name = in_place_method_name;
2767 } else if (type->findMethod(out_of_place_method_name)) {
2768 magic_method_name = out_of_place_method_name;
2769 } else {
2770 throw(
2771 ErrorReport(stmt.range())
2772 << "Cannot emit inplace op on " << type->repr_str()
2773 << " since it does not define an " << in_place_method_name << " or "
2774 << out_of_place_method_name << " method");
2775 }
2776
2777 // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
2778 // __iadd__ is not present
2779 return MethodValue(lhs, magic_method_name)
2780 .call(stmt.range(), method, {rhs}, {}, 0)
2781 ->asValue(stmt.range(), method);
2782 } else {
2783 const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()))
2784 .value(*method.graph());
2785 return emitBuiltinCall(
2786 stmt.range(),
2787 *method.graph(),
2788 getAugOp(stmt, lhs->type()),
2789 /*args=*/{lhs, rhs},
2790 /*kwargs=*/{},
2791 /*self=*/std::nullopt);
2792 }
2793 }
2794
  // Emit an augmented assignment `container[idx] op= value` for non-tensor
  // containers (lists and dicts). Lowers to:
  //   tmp = container.__getitem__(idx)
  //   container._set_item(idx, tmp <op> value)
  void emitAugAssignmentGeneric(
      const AugAssign& stmt,
      const Subscript& lhs,
      Value* sliceable) {
    // Get the idx to augment
    const auto subscriptExprs = lhs.subscript_exprs();
    const TypePtr type = sliceable->type();
    // Only a single subscript expression is supported on this path.
    if (subscriptExprs.size() != 1) {
      throw(
          ErrorReport(subscriptExprs)
          << "Sliced expression not yet supported for " << type->repr_str()
          << " augmented assignment. "
          << "File a bug if you want this");
    }

    // The element type drives which aug op getAugOp() selects.
    TypePtr elemType = nullptr;
    if (const ListTypePtr listType = type->cast<ListType>()) {
      elemType = listType->getElementType();
    } else if (const DictTypePtr dictType = type->cast<DictType>()) {
      // NOTE(review): this passes the dict's *key* type to getAugOp even
      // though the augmented op operates on the mapped values — confirm
      // whether getValueType() was intended here.
      elemType = dictType->getKeyType();
    }

    if (elemType == nullptr) {
      throw(
          ErrorReport(lhs) << type->repr_str()
                           << " does not support augmented assignment.");
    }
    const auto idxValue = emitExpr(subscriptExprs[0]);
    const auto containerArg =
        NamedValue(lhs.value().range(), type->str(), sliceable);
    const auto idxArg = NamedValue(subscriptExprs.range(), "idx", idxValue);
    const auto valueArg =
        NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));

    // container[idx]
    const auto getItem = graph->insert(
        aten::__getitem__, {containerArg, idxArg}, {}, stmt.range());
    // container[idx] <op> value
    const auto augmentedItem = graph->insert(
        getAugOp(stmt, elemType), {getItem, valueArg}, {}, stmt.range());
    // container[idx] = <augmented value>
    graph->insert(
        aten::_set_item,
        {containerArg, idxArg, augmentedItem},
        {},
        stmt.range());
  }
2839
  // Handles `base[sub] op= rhs`. Tensors are sliced first and then mutated
  // in place (via the in-place aug op or index_put_); other containers are
  // delegated to emitAugAssignmentGeneric.
  void emitAugAssignmentToSubscript(const AugAssign& stmt) {
    // Process the base list value
    const auto lhs = Subscript(stmt.lhs());
    const auto sliceable = emitExpr(lhs.value());

    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
      // If it's a tensor, just fully evaluate the subscript operation and emit
      // an in-place assignment
      auto [sliced, tensorIndices] = emitIntAndSliceIndexing(
          lhs.range(), sliceable, lhs.subscript_exprs());

      const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced);
      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
      if (tensorIndices.empty()) {
        // Common case: we only tried to index with int and slices. Emit the
        // correct augmented assignment op to the sliced value
        emitBuiltinCall(
            stmt.range(),
            *method.graph(),
            getAugOp(stmt, sliceable->type()),
            {rhs},
            {},
            slicedArg);
      } else {
        // Special case: we tried to do "advanced indexing". Lower this expr
        // into `index` and `index_put_` ops with tensorIndices of Tensor?[]
        const auto indices = graph
                                 ->insertNode(graph->createList(
                                     OptionalType::ofTensor(), tensorIndices))
                                 ->output();
        // gathered = sliced.index(indices)
        const auto indexed =
            graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
        // augmented = gathered <op> rhs
        const auto augmented = emitBuiltinCall(
            stmt.range(),
            *method.graph(),
            getAugOp(stmt, sliceable->type()),
            {rhs},
            {},
            indexed);
        // sliced.index_put_(indices, augmented)
        graph->insert(
            aten::index_put_,
            {slicedArg, indices, augmented},
            {},
            stmt.range());
      }
    } else {
      emitAugAssignmentGeneric(stmt, lhs, sliceable);
    }
  }
2889
emitValueToTensortorch::jit::to_ir2890 NamedValue emitValueToTensor(
2891 const NamedValue& value,
2892 const NamedValue& matchTypeOf) {
2893 // Add implicit conversion of int/float/complex/bool/number types to tensors
2894 // Used in emitSubscriptAssign to convert:
2895 // `tensor(...)[x] = 99` to `tensor(...)[x] = tensor(99)`
2896 // Mirrors the `valueToTensor` behavior in python_variable_indexing.cpp
2897 const auto kind = value.type()->kind();
2898 if (kind == c10::TypeKind::NumberType || kind == c10::TypeKind::IntType ||
2899 kind == c10::TypeKind::BoolType || kind == c10::TypeKind::FloatType ||
2900 kind == c10::TypeKind::ComplexType) {
2901 auto dtype = graph->insert(prim::dtype, {matchTypeOf}, {});
2902 auto device = graph->insert(prim::device, {matchTypeOf}, {});
2903 auto converted = graph->insert(
2904 aten::tensor,
2905 {value},
2906 {NamedValue("dtype", dtype), NamedValue("device", device)});
2907 return NamedValue(value.loc(), converted);
2908 }
2909
2910 return value;
2911 }
2912
2913 // Emit mutating assignments like `foo[0] = bar`
emitSubscriptAssigntorch::jit::to_ir2914 void emitSubscriptAssign(
2915 const SourceRange& stmtRange,
2916 const Subscript& lhs,
2917 const Expr& rhs) {
2918 emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
2919 }
2920
  // Emit a subscripted store `lhs.value()[idx] = rhs` where the RHS has
  // already been evaluated. Tensors are sliced and the RHS is copied /
  // index_put_ into them; other containers dispatch to __setitem__ or
  // aten::_set_item.
  void emitSubscriptAssign(
      const SourceRange& stmtRange,
      const Subscript& lhs,
      const NamedValue& rhs) {
    // First check the base value.
    auto sliceable = emitExpr(lhs.value());

    // If it's a tensor, copy the RHS data into it
    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
      // Handle multi-dimensional slicing: first emit int/slice indexing
      // TODO: the Python equivalent code has special-cased copy_to
      // broadcasting to match NumPy semantics (see PR#4853). We can't
      // replicate that without knowing the size of the Tensor; so really that
      // code should be moved into the aten function
      auto [sliced, tensorIndices] = emitIntAndSliceIndexing(
          lhs.range(), sliceable, lhs.subscript_exprs());

      const auto slicedArg = NamedValue(lhs.range(), sliced);

      // rhs must be a tensor, implicitly convert int/float/complex/bool
      const auto convertedRhs = emitValueToTensor(rhs, slicedArg);

      if (tensorIndices.empty()) {
        // Common case: we only tried to index with int and slices. Copy the
        // RHS into the resulting tensor.
        graph->insert(aten::copy_, {slicedArg, convertedRhs}, {}, stmtRange);
      } else {
        // Special case: we tried to do "advanced indexing" with a tensor.
        // Dispatch to `aten::index_put_` with tensorIndices of Tensor?[]
        const auto indices = graph
                                 ->insertNode(graph->createList(
                                     OptionalType::ofTensor(), tensorIndices))
                                 ->output();

        graph->insert(
            aten::index_put_,
            {slicedArg, indices, convertedRhs},
            {},
            stmtRange);
      }
      // Otherwise, this is a list or a classtype.
      // Dispatch to aten::_set_item to both select and assign
    } else {
      const auto subscript = lhs.subscript_exprs();
      // Only single, non-slice subscripts are supported on this path.
      if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
        throw(
            ErrorReport(subscript) << "Sliced expression not yet supported for"
                                   << " subscripted assignment. "
                                   << "File a bug if you want this");
      }
      // Tuples are immutable; reject subscripted assignment outright.
      if (sliceable->type()->isSubtypeOf(*AnyTupleType::get())) {
        throw(
            ErrorReport(lhs) << sliceable->type()->repr_str()
                             << " does not support subscripted assignment");
      }

      std::vector<NamedValue> args;
      args.emplace_back(lhs.value().range(), "self", sliceable);
      args.emplace_back(
          lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0]));
      args.push_back(rhs);
      // Prefer a user-defined __setitem__ if the type has one; otherwise
      // fall back to the aten::_set_item builtin.
      makeMagic(
          "__setitem__",
          std::make_shared<BuiltinFunction>(aten::_set_item, std::nullopt))
          ->call(stmtRange, method, args, {}, 0);
    }
  }
2988
emitTupleAssigntorch::jit::to_ir2989 void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
2990 size_t n_binders = tl.inputs().size();
2991 bool starred_unpack = validateAssignLhsExpr(tl.inputs(), tl.range());
2992 if (starred_unpack)
2993 n_binders--;
2994 auto output = emitSugaredExpr(rhs, n_binders);
2995 emitTupleAssign(tl, output, rhs.range(), n_binders, starred_unpack);
2996 }
2997
emitTupleAssigntorch::jit::to_ir2998 void emitTupleAssign(
2999 const TupleLiteral& tl,
3000 const SugaredValuePtr& rhs_output,
3001 const SourceRange& rhs_loc,
3002 size_t n_binders,
3003 bool starred_unpack) {
3004 auto outputs = rhs_output->asTuple(
3005 rhs_loc,
3006 method,
3007 starred_unpack ? std::nullopt : std::optional<size_t>{n_binders});
3008 if (outputs.size() < n_binders) {
3009 throw(
3010 ErrorReport(tl) << "need " << (starred_unpack ? "at least " : "")
3011 << n_binders << " values to unpack but found only "
3012 << outputs.size());
3013 }
3014 if (outputs.size() > n_binders && !starred_unpack) {
3015 throw(
3016 ErrorReport(tl) << "too many values to unpack: need " << n_binders
3017 << " but found " << outputs.size());
3018 }
3019
3020 emitExprsAssign(tl.inputs(), outputs, rhs_loc, n_binders);
3021 }
3022
  // Assign each element of `outputs` to the corresponding target in
  // `lhs_exprs`. Targets may be subscripts, variables, starred targets,
  // nested tuple literals, or attribute selects. `n_binders` is the number
  // of non-starred targets, used to size the starred remainder.
  void emitExprsAssign(
      const List<Expr>& lhs_exprs,
      const at::ArrayRef<SugaredValuePtr> outputs,
      const SourceRange& rhs_loc,
      size_t n_binders) {
    size_t i = 0; // index of the next unconsumed output
    for (auto assignee : lhs_exprs) {
      switch (assignee.kind()) {
        case TK_SUBSCRIPT:
          emitSubscriptAssign(
              rhs_loc,
              Subscript(assignee),
              NamedValue(rhs_loc, outputs.at(i)->asValue(rhs_loc, method)));
          i++;
          break;
        case TK_VAR:
          environment_stack->setSugaredVar(
              assignee.range(),
              Var(assignee).name().name(),
              outputs.at(i),
              /*annotated_type=*/nullptr);
          i++;
          break;
        case TK_STARRED: {
          auto var = Starred(assignee).expr();
          if (var.kind() != TK_VAR) {
            throw(
                ErrorReport(var) << "Cannot pack a tuple into a non-variable");
          }
          // The starred target consumes everything not claimed by the
          // n_binders fixed targets.
          size_t n_matched = outputs.size() - n_binders;
          ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
          auto values = fmap(
              outputs_ref.slice(i, n_matched),
              [&](const std::shared_ptr<SugaredValue>& v) {
                return v->asValue(assignee.range(), method);
              });
          // Pack the matched values into a tuple bound to the starred name.
          auto tup = graph->insertNode(graph->createTuple(values))->output();
          environment_stack->setVar(var.range(), Var(var).name().name(), tup);
          i += n_matched;
        } break;
        case TK_TUPLE_LITERAL: {
          // recursively emit tuple assignments on tuple literal input
          TupleLiteral sub_tl = TupleLiteral(assignee);
          size_t sub_n_binders = sub_tl.inputs().size();
          bool sub_starred_unpack =
              validateAssignLhsExpr(sub_tl.inputs(), sub_tl.range());
          if (sub_starred_unpack)
            sub_n_binders--;
          emitTupleAssign(
              sub_tl,
              outputs.at(i),
              rhs_loc,
              sub_n_binders,
              sub_starred_unpack);
          i++;
        } break;
        case '.': {
          emitSelectAssign(assignee, outputs.at(i), rhs_loc);
          i++;
        } break;
        default:
          throw(
              ErrorReport(assignee)
              << "unexpected expression on the left-hand side");
      }
    }
  }
3090
  // Emit an assignment statement. Single-target assignments are delegated
  // directly; chained assignments `a = b = expr()` evaluate the RHS once
  // into a temporary and then assign it to each target left-to-right.
  void emitAssignment(const Assign& stmt) {
    if (stmt.lhs_list().size() == 1) {
      return emitSingleAssignment(stmt);
    }
    // multiple assign & annotated type not supported in python
    TORCH_INTERNAL_ASSERT(stmt.lhs_list().size() > 1 && !stmt.type().present());
    // a = b = expr()
    // the semantics of multiple assignment is that expr() is emitted once, then
    // from left to right the assignments are made
    const auto tmp_name = createTempName("$tmp_assign_");
    environment_stack->setSugaredVar(
        stmt.rhs().range(),
        tmp_name,
        emitSugaredExpr(stmt.rhs().get(), 1),
        /*annotated_type=*/nullptr);
    // A Var referring to the temporary; reused as the RHS of each
    // synthesized single assignment below.
    auto ident = Var::create(
        stmt.rhs().range(), Ident::create(stmt.rhs().range(), tmp_name));
    for (auto expr : stmt.lhs_list()) {
      emitSingleAssignment(Assign::create(
          stmt.range(),
          List<Expr>::create(expr.range(), {expr}),
          Maybe<Expr>::create(stmt.rhs().range(), ident),
          Maybe<Expr>::create(stmt.range())));
    }
  }
3116
  // Emit a single (one-LHS) assignment, dispatching on the kind of the
  // target: variable, tuple literal, attribute select, or subscript.
  void emitSingleAssignment(const Assign& stmt) {
    if (!stmt.rhs().present()) {
      throw(
          ErrorReport(stmt.range())
          << "For an assignment, expected an expression on the right-hand side");
    }
    const Expr& rhs = stmt.rhs().get();
    switch (stmt.lhs().kind()) {
      case TK_VAR: {
        auto v = Var(stmt.lhs());
        // Parse the optional `x: T = ...` annotation, if present.
        TypePtr type = nullptr;
        if (stmt.type().present()) {
          type = typeParser_.parseTypeFromExpr(stmt.type().get());
        }
        auto rhs_sugared_val = emitSugaredExpr(rhs, 1, type);
        // START BC HACK
        //
        // For old serialized quantized RNN modules, switch
        // quantized::linear_prepack to quantized::linear_prepack_legacy. We
        // changed linear_prepack to return a TorchBind class and not a
        // cpp_custom_type_hack tensor anymore, but the old serialized models
        // are tightly coupled with the type_hack version. If we still create a
        // Tensor here, then the quantized_lstm.legacy overload can kick in in
        // forward_impl(), and the module will still run correctly.
        if (method.qualname() ==
            "__torch__.torch.nn.quantized.dynamic.modules.rnn.PackedParameter.__setstate__") {
          if (auto sv =
                  std::dynamic_pointer_cast<SimpleValue>(rhs_sugared_val)) {
            Node* rhs_node = sv->getValue()->node();
            if (rhs_node->kind() ==
                Symbol::fromQualString("quantized::linear_prepack")) {
              // Re-emit the same call against the legacy overload, keeping
              // the original inputs and source range.
              std::vector<NamedValue> inputs;
              for (Value* i : rhs_node->inputs()) {
                inputs.emplace_back(i);
              }
              Value* new_val = rhs_node->owningGraph()->insert(
                  Symbol::fromQualString("quantized::linear_prepack_legacy"),
                  inputs,
                  {},
                  rhs_node->sourceRange());
              rhs_sugared_val = std::make_shared<SimpleValue>(new_val);
            }
          }
        }
        // END BC HACK
        environment_stack->setSugaredVar(
            v.range(),
            v.name().name(),
            std::move(rhs_sugared_val),
            /*annotated_type=*/type);
      } break;
      case TK_TUPLE_LITERAL:
        emitTupleAssign(TupleLiteral(stmt.lhs()), rhs);
        break;
      case '.':
        emitSelectAssign(stmt);
        break;
      case TK_SUBSCRIPT:
        emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), rhs);
        break;
      default:
        throw(
            ErrorReport(stmt.lhs())
            << "unexpected expression on left-hand side of assignment");
    }
  }
3183
emitSelectAssigntorch::jit::to_ir3184 void emitSelectAssign(const Assign& stmt) {
3185 if (!stmt.rhs().present()) {
3186 throw(ErrorReport(stmt.range()) << "Expected RHS for assignment");
3187 }
3188
3189 TypePtr type_hint = nullptr;
3190 if (stmt.type().present()) {
3191 type_hint = typeParser_.parseTypeFromExpr(stmt.type().get());
3192 }
3193 const auto lhs = Select(stmt.lhs());
3194 auto lhsObject = emitSugaredExpr(lhs.value(), 1);
3195 const auto rhsValue = emitSugaredExpr(stmt.rhs().get(), 1, type_hint)
3196 ->asValue(stmt.rhs().range(), method);
3197 lhsObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue);
3198 }
3199
emitSelectAssigntorch::jit::to_ir3200 void emitSelectAssign(
3201 const Expr& lhs,
3202 const SugaredValuePtr& rhs,
3203 const SourceRange& loc) {
3204 const auto lhs_select = Select(lhs);
3205 auto lhs_sv = emitSugaredExpr(lhs_select.value(), 1);
3206 const auto rhs_value = rhs->asValue(loc, method);
3207 lhs_sv->setAttr(loc, method, lhs_select.selector().name(), rhs_value);
3208 }
3209
getNodeKindtorch::jit::to_ir3210 NodeKind getNodeKind(int kind, size_t ninputs) {
3211 switch (kind) {
3212 case '+':
3213 return aten::add;
3214 case '-':
3215 return aten::sub;
3216 case TK_UNARY_MINUS:
3217 return aten::neg;
3218 case '*':
3219 return aten::mul;
3220 case TK_POW:
3221 return aten::pow;
3222 case '@':
3223 return aten::matmul;
3224 case TK_STARRED:
3225 return prim::Starred;
3226 case '/':
3227 return aten::div;
3228 case '%':
3229 return aten::remainder;
3230 case TK_NE:
3231 return aten::ne;
3232 case TK_EQ:
3233 return aten::eq;
3234 case '<':
3235 return aten::lt;
3236 case '>':
3237 return aten::gt;
3238 case TK_LE:
3239 return aten::le;
3240 case TK_GE:
3241 return aten::ge;
3242 case TK_AND:
3243 return aten::__and__;
3244 case TK_OR:
3245 return aten::__or__;
3246 case TK_IS:
3247 return aten::__is__;
3248 case TK_ISNOT:
3249 return aten::__isnot__;
3250 case TK_NOT:
3251 return aten::__not__;
3252 case TK_FLOOR_DIV:
3253 return aten::floordiv;
3254 case TK_LSHIFT:
3255 return aten::__lshift__;
3256 case TK_RSHIFT:
3257 return aten::__rshift__;
3258 case '&':
3259 return aten::__and__;
3260 case '|':
3261 return aten::__or__;
3262 case '^':
3263 return aten::__xor__;
3264 case TK_IN:
3265 return aten::__contains__;
3266 default:
3267 throw std::runtime_error("unknown kind " + std::to_string(kind));
3268 }
3269 }
3270
getOperatorOverloadtorch::jit::to_ir3271 std::string getOperatorOverload(int kind, size_t ninputs) {
3272 switch (kind) {
3273 case '+':
3274 return "__add__";
3275 case '-':
3276 return "__sub__";
3277 case TK_UNARY_MINUS:
3278 return "__neg__";
3279 case '~':
3280 return "__invert__";
3281 case '*':
3282 return "__mul__";
3283 case TK_POW:
3284 return "__pow__";
3285 case '/':
3286 return "__truediv__";
3287 case '%':
3288 return "__mod__";
3289 case TK_NE:
3290 return "__ne__";
3291 case TK_EQ:
3292 return "__eq__";
3293 case '<':
3294 return "__lt__";
3295 case '>':
3296 return "__gt__";
3297 case TK_LE:
3298 return "__le__";
3299 case TK_GE:
3300 return "__ge__";
3301 case '&':
3302 return "__and__";
3303 case '|':
3304 return "__or__";
3305 case '^':
3306 return "__xor__";
3307 case TK_IN:
3308 return "__contains__";
3309 case TK_LSHIFT:
3310 return "__lshift__";
3311 case TK_RSHIFT:
3312 return "__rshift__";
3313 default:
3314 throw std::runtime_error("unknown kind " + std::to_string(kind));
3315 }
3316 }
3317
getNamedValuestorch::jit::to_ir3318 std::vector<NamedValue> getNamedValues(
3319 const TreeList& trees,
3320 bool maybe_unpack) {
3321 std::vector<NamedValue> values;
3322 for (const auto& tree : trees) {
3323 if (maybe_unpack && tree->kind() == TK_STARRED) {
3324 auto starred = Starred(tree);
3325 auto entries = emitSugaredExpr(starred.expr(), 1)
3326 ->asTuple(starred.range(), method);
3327 for (const auto& entry : entries) {
3328 values.emplace_back(
3329 tree->range(), entry->asValue(starred.range(), method));
3330 }
3331 } else {
3332 values.emplace_back(tree->range(), emitExpr(Expr(tree)));
3333 }
3334 }
3335 return values;
3336 }
getNamedValuestorch::jit::to_ir3337 std::vector<NamedValue> getNamedValues(
3338 const List<Expr>& trees,
3339 bool maybe_unpack) {
3340 return getNamedValues(trees.tree()->trees(), maybe_unpack);
3341 }
3342
getValuestorch::jit::to_ir3343 std::vector<Value*> getValues(const TreeList& trees, bool maybe_unpack) {
3344 return toValues(*graph, getNamedValues(trees, maybe_unpack));
3345 }
getValuestorch::jit::to_ir3346 std::vector<Value*> getValues(const List<Expr>& trees, bool maybe_unpack) {
3347 return getValues(trees.tree()->trees(), maybe_unpack);
3348 }
3349
emitAttributestorch::jit::to_ir3350 std::vector<NamedValue> emitAttributes(const List<Attribute>& attributes) {
3351 return fmap(attributes, [&](const Attribute& attr) {
3352 return NamedValue(
3353 attr.range(), attr.name().name(), emitExpr(attr.value()));
3354 });
3355 }
3356
checkApplyNumInputstorch::jit::to_ir3357 void checkApplyNumInputs(const Apply& apply, size_t expected_inputs) {
3358 const SourceRange& loc = apply.range();
3359 if (apply.inputs().size() != expected_inputs) {
3360 throw(
3361 ErrorReport(loc) << Var(apply.callee()).name().name()
3362 << " expected exactly " << expected_inputs
3363 << " arguments but found " << apply.inputs().size());
3364 }
3365 if (!apply.attributes().empty()) {
3366 throw(
3367 ErrorReport(loc) << Var(apply.callee()).name().name()
3368 << " takes no keyword arguments");
3369 }
3370 }
3371
checkApplyNumInputsRangetorch::jit::to_ir3372 void checkApplyNumInputsRange(
3373 const Apply& apply,
3374 size_t min_expected_inputs,
3375 size_t max_expected_inputs) {
3376 const SourceRange& loc = apply.range();
3377 size_t position_arg_size = apply.inputs().size();
3378 if (position_arg_size < min_expected_inputs ||
3379 position_arg_size > max_expected_inputs) {
3380 throw(
3381 ErrorReport(loc) << Var(apply.callee()).name().name()
3382 << " expected to have number of arguments between "
3383 << min_expected_inputs << " and "
3384 << max_expected_inputs << " but found "
3385 << position_arg_size);
3386 }
3387 if (!apply.attributes().empty()) {
3388 throw(
3389 ErrorReport(loc) << Var(apply.callee()).name().name()
3390 << " takes no keyword arguments");
3391 }
3392 }
3393
emitApplyExprtorch::jit::to_ir3394 std::shared_ptr<SugaredValue> emitApplyExpr(
3395 Apply& apply,
3396 size_t n_binders,
3397 const TypePtr& type_hint = nullptr) {
3398 auto sv = emitSugaredExpr(apply.callee(), 1);
3399 auto loc = apply.callee().range();
3400 if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
3401 return emitApplySpecialForm(special_form->form(), apply, sv, type_hint);
3402 }
3403 auto args = getNamedValues(apply.inputs(), true);
3404 auto kwargs = emitAttributes(apply.attributes());
3405 return sv->call(loc, method, args, kwargs, n_binders);
3406 }
3407
3408 // this function handles expressions that look like apply statements
3409 // but have special evaluation rules for the arguments.
3410 // when adding a new case, only add a special form if it cannot be expressed
3411 // using the standard SugaredValue::call function, which enforces normal
3412 // evaluation order.
emitApplySpecialFormtorch::jit::to_ir3413 std::shared_ptr<SugaredValue> emitApplySpecialForm(
3414 Symbol form,
3415 Apply& apply,
3416 const std::shared_ptr<SugaredValue>& sv,
3417 const TypePtr& type_hint = nullptr) {
3418 switch (form) {
3419 case prim::fork: {
3420 auto& trees = apply.inputs().tree()->trees();
3421 if (trees.empty()) {
3422 throw(
3423 ErrorReport(apply) << "Expected at least one argument to fork()");
3424 }
3425 auto forked = emitSugaredExpr(Expr(trees[0]), 1);
3426 TreeList sliced_trees(trees.begin() + 1, trees.end());
3427 auto args = getNamedValues(sliced_trees, true);
3428 auto kwargs = emitAttributes(apply.attributes());
3429 return emitForkExpr(apply.range(), forked, args, kwargs);
3430 }
3431 case prim::awaitable: {
3432 auto tree = apply.inputs().tree();
3433 if (!tree || tree->trees().empty()) {
3434 throw(
3435 ErrorReport(apply)
3436 << "Expected at least one argument to awaitable()");
3437 }
3438 auto& trees = tree->trees();
3439 auto awaited = emitSugaredExpr(Expr(trees[0]), 1);
3440 TreeList sliced_trees(trees.begin() + 1, trees.end());
3441 auto args = getNamedValues(sliced_trees, true);
3442 auto kwargs = emitAttributes(apply.attributes());
3443 return emitAwaitableExpr(apply.range(), awaited, args, kwargs);
3444 }
3445 case prim::annotate: {
3446 checkApplyNumInputs(apply, 2);
3447 TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
3448 Value* expr = tryConvertToType(
3449 apply.range(),
3450 *graph,
3451 type,
3452 emitExpr(apply.inputs()[1], type),
3453 /*allow_conversions=*/true);
3454
3455 std::stringstream why_not;
3456 if (!expr->type()->isSubtypeOfExt(*type, &why_not)) {
3457 throw(
3458 ErrorReport(apply.inputs())
3459 << "expected an expression of type " << type->repr_str()
3460 << " but found " << expr->type()->repr_str() << "\n"
3461 << why_not.str());
3462 }
3463
3464 // None is a subtype of Optional[T], but we want to remember what T is
3465 // after annotation so that variables assigned to this None will still
3466 // get the right type. To do this, we make a None constant that
3467 // has the type Optional[T]
3468 if ((type->kind() == OptionalType::Kind ||
3469 (type->kind() == UnionType::Kind &&
3470 type->expect<UnionType>()->canHoldType(*NoneType::get()))) &&
3471 expr->type()->isSubtypeOf(*NoneType::get())) {
3472 Node* none = graph->createNone();
3473 none->output()->setType(type);
3474 graph->insertNode(none);
3475 expr = none->output();
3476 }
3477
3478 return std::make_shared<SimpleValue>(expr);
3479 }
3480 case prim::rpc_async:
3481 case prim::rpc_sync:
3482 case prim::rpc_remote: {
3483 return emitRpcExpr(apply, form);
3484 }
3485 case prim::unchecked_cast: {
3486 checkApplyNumInputs(apply, 2);
3487 TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
3488 Value* v = emitExpr(apply.inputs()[1]);
3489 // avoid generating nested unchecked_casts because they are already
3490 // inserted during serialization
3491 if (v->node()->kind() != prim::unchecked_cast || *v->type() != *type) {
3492 v = graph->insertUncheckedCast(v, type);
3493 }
3494 return std::make_shared<SimpleValue>(v);
3495 } break;
3496 case prim::GetAttr: {
3497 checkApplyNumInputsRange(apply, 2, 3);
3498 auto obj = emitSugaredExpr(apply.inputs()[0], 1);
3499 auto selector = apply.inputs()[1];
3500 if (selector.kind() != TK_STRINGLITERAL) {
3501 throw(
3502 ErrorReport(apply)
3503 << "getattr's second argument must be a string literal");
3504 }
3505 const std::string& name = StringLiteral(selector).text();
3506
3507 if (apply.inputs().size() == 2) {
3508 return obj->attr(apply.range(), method, name);
3509 } else {
3510 // 3 inputs form of getattr, the third argument is the default value
3511 // to return when attribute is not found
3512 if (obj->hasAttr(apply.range(), method, name)) {
3513 return obj->attr(apply.range(), method, name);
3514 } else {
3515 // attribute not found, just default val (3rd arg)
3516 return emitSugaredExpr(apply.inputs()[2], 1);
3517 }
3518 }
3519 } break;
3520 case prim::Uninitialized: {
3521 checkApplyNumInputs(apply, 1);
3522 TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
3523 auto out = graph->insertNode(graph->createUninitialized(type))
3524 ->setSourceRange(apply.range());
3525 return std::make_shared<SimpleValue>(out->output());
3526 }
3527 case prim::TupleConstruct: {
3528 checkApplyNumInputs(apply, 1);
3529 auto arg = emitSugaredExpr(apply.inputs()[0], 1);
3530 auto inputs = arg->asTuple(apply.range(), method);
3531 auto inp_values = fmap(inputs, [&](const SugaredValuePtr& sv) {
3532 return sv->asValue(apply.range(), method);
3533 });
3534 return std::make_shared<SimpleValue>(
3535 graph->insertNode(graph->createTuple(inp_values))->output());
3536 }
3537 case prim::LegacyTypedConstructor: {
3538 // see legacy_tensor_generic_ctor_new
3539 // These legacy constructors do not follow schemas that can be
3540 // typed in native_functions.yaml / JIT type signature and are handled
3541 // here. Only the two common cases are handled initially:
3542 // "new(IntArrayRef size, *, Device? device=None)",
3543 // "new(PyObject* data, *, Device? device=None)",
3544 // Note: device argument is unused in the kernel
3545 auto args = getValues(apply.inputs(), true);
3546 auto kwargs = emitAttributes(apply.attributes());
3547 auto get_base_error_msg = [&]() {
3548 std::stringstream base_error_msg;
3549 base_error_msg
3550 << "Legacy Tensor Constructor only supports two schemas in TorchScript: \n";
3551 base_error_msg
3552 << "'new(IntArrayRef size, *, Device? device=None)',\n";
3553 base_error_msg << "'new(PyObject* data, *, Device? device=None)\n'";
3554 return base_error_msg;
3555 };
3556 if (kwargs.size() == 1 && kwargs[0].name() != "device") {
3557 throw(
3558 ErrorReport(apply) << get_base_error_msg().str() << "Got kwarg "
3559 << kwargs[0].name());
3560 }
3561 if (kwargs.size() > 1) {
3562 throw(
3563 ErrorReport(apply)
3564 << get_base_error_msg().str() << "Got multiple kwargs\n");
3565 }
3566 auto dtype = dynamic_cast<LegacyTensorConstructor*>(sv.get())->dtype();
3567 auto dtype_ivalue = graph->insertConstant(dtype);
3568
3569 // supporting "new(IntArrayRef size, *, Device? device=None)", through
3570 // empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout?
3571 // layout=None, Device? device=None, bool? pin_memory=None,
3572 // MemoryFormat? memory_format=None) -> Tensor
3573 bool all_ints = std::all_of(args.begin(), args.end(), [](Value* v) {
3574 return v->type()->cast<IntType>();
3575 });
3576 if (args.empty()) {
3577 // empty inputs == torch.tensor([], dtype=....)
3578 auto inp_list =
3579 graph->insertNode(graph->createList(IntType::get(), {}))
3580 ->output();
3581 return std::make_shared<SimpleValue>(graph->insert(
3582 aten::tensor,
3583 {inp_list},
3584 {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
3585 } else if (all_ints) {
3586 auto inp_list =
3587 graph->insertNode(graph->createList(IntType::get(), args))
3588 ->output();
3589 return std::make_shared<SimpleValue>(graph->insert(
3590 aten::empty,
3591 {inp_list},
3592 {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
3593 } else if (args.size() == 1) {
3594 return std::make_shared<SimpleValue>(graph->insert(
3595 aten::tensor,
3596 {args[0]},
3597 {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
3598 } else {
3599 throw(
3600 ErrorReport(apply)
3601 << get_base_error_msg().str()
3602 << "Got multiple positional arguments that were not all integers");
3603 }
3604 }
3605 case prim::isinstance: {
3606 checkApplyNumInputs(apply, 2);
3607 auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
3608 return std::make_shared<SimpleValue>(result.value());
3609 }
3610 case prim::tolist: {
3611 auto select = Select(apply.callee());
3612 auto value = select.value();
3613 auto operand = emitSugaredExpr(value, 1);
3614
3615 if (!type_hint) {
3616 throw(
3617 ErrorReport(apply)
3618 << "Expected type hint for result of tolist()");
3619 }
3620
3621 return std::make_shared<SimpleValue>(graph->insertToList(
3622 operand->asValue(value.range(), method), type_hint));
3623 }
3624 case prim::HasAttr: {
3625 checkApplyNumInputs(apply, 2);
3626 const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
3627 return std::make_shared<SimpleValue>(result.value());
3628 } break;
3629 // This represents the "__new__" method on classes
3630 // because it takes a ClassValue as input.
3631 // So if we see:
3632 // Foo.__new__(Foo)
3633 // Foo is a ClassValue, calling `attr("__new__")` will return a
3634 // CreateObject special form.
3635 case prim::CreateObject: {
3636 if (apply.inputs().size() != 1) {
3637 throw(ErrorReport(apply) << "Only one argument to __new__ allowed");
3638 }
3639 auto arg = emitSugaredExpr(apply.inputs()[0], 1);
3640 auto class_arg = dynamic_cast<ClassValue*>(arg.get());
3641 if (!class_arg) {
3642 throw(
3643 ErrorReport(apply)
3644 << "Expected class value as argument to __new__, got "
3645 << arg->kind() << " instead");
3646 }
3647 auto createNode =
3648 graph->insertNode(graph->createObject(class_arg->type_));
3649 createNode->setSourceRange(apply.range());
3650 return std::make_shared<SimpleValue>(createNode->output());
3651 }
3652 // We construct the iterable tree here using the IterableTree
3653 // SugaredValue, The tree consists of SimpleValue, RangeValue or
3654 // IterableTree: For SimpleValues(List, Dict, etc) or RangeValue. We will
3655 // make them as tree leaves since we could get the loop information from
3656 // len() and get_item(). For IterableTree like zip(), enumerate(), we can
3657 // model them as a combination of leaves, and we emit a IterableTree value
3658 // to record the tree information
3659 case prim::range: {
3660 std::vector<Value*> input_vals =
3661 getValues(apply.inputs(), /*maybe_unpack=*/true);
3662 return std::make_shared<RangeValue>(apply.range(), method, input_vals);
3663 }
3664 case prim::enumerate: {
3665 const SourceRange& loc = apply.range();
3666 auto inputs = apply.inputs();
3667 auto input_size = inputs.size();
3668 auto attributes = apply.attributes();
3669 auto attribute_size = attributes.size();
3670 // enumerate(x) can be rewrite as subtrees:
3671 // IterableTree(RangeValue(0, math.inf), SimpleValue(x))
3672 Value* start_index = nullptr;
3673 if (input_size == 0) {
3674 throw(
3675 ErrorReport(loc)
3676 << "enumerate expected at least 1 arguments, got 0");
3677 }
3678
3679 if (input_size == 2) {
3680 start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method);
3681 }
3682 auto arg_size = input_size + attribute_size;
3683 if (arg_size > 2) {
3684 throw(
3685 ErrorReport(loc)
3686 << "enumerate expected at most 2 arguments, got " << arg_size);
3687 }
3688
3689 if (attribute_size == 1) {
3690 if (attributes[0].name().name() != "start") {
3691 throw(
3692 ErrorReport(loc)
3693 << "enumerate expected kwarg name 'start', got '"
3694 << attributes[0].name().name() << "'");
3695 }
3696 start_index =
3697 emitSugaredExpr(attributes[0].value(), 1)->asValue(loc, method);
3698 }
3699
3700 std::vector<Value*> range_inputs;
3701 if (start_index != nullptr) {
3702 range_inputs.emplace_back(start_index);
3703 }
3704 Value* end = materializeConstant(
3705 std::numeric_limits<int64_t>::max(),
3706 *graph,
3707 loc,
3708 integral_constants);
3709 range_inputs.emplace_back(end);
3710 SugaredValuePtr expr_sv = emitSugaredExpr(inputs[0], 1);
3711 auto iterable_value = expr_sv->iter(loc, method);
3712
3713 // range should have the same static length as the other iterable
3714 std::optional<int64_t> iter_static_len = iterable_value->staticLen();
3715 SugaredValuePtr range_sv = std::make_shared<RangeValue>(
3716 loc, method, range_inputs, iter_static_len);
3717
3718 auto tree = std::make_shared<IterableTree>();
3719 tree->addChild(loc, method, range_sv);
3720 tree->addChild(loc, method, iterable_value);
3721 return tree;
3722 }
3723 case prim::zip: {
3724 // zip(x, y) can be rewrite as subtrees:
3725 // IterableTree(IterableTree(x), IterableTree(y))
3726 auto inputs = apply.inputs();
3727 if (inputs.empty()) {
3728 throw(
3729 ErrorReport(apply) << "zip expected at least 1 arguments, got 0");
3730 }
3731 auto iterable_tree = std::make_shared<IterableTree>();
3732 for (Expr expr : inputs) {
3733 auto iterable = emitSugaredExpr(expr, 1)->iter(apply.range(), method);
3734 iterable_tree->addChild(apply.range(), method, iterable);
3735 }
3736 return iterable_tree;
3737 }
3738 case prim::list: {
3739 return emitApplySpecialFormForList(apply, type_hint);
3740 }
3741 case prim::dict: {
3742 return emitApplySpecialFormForDict(apply, type_hint);
3743 }
3744 case aten::index: {
3745 const SourceRange& loc = apply.range();
3746 auto select = Select(apply.callee());
3747 auto self = emitSugaredExpr(select.value(), 1)->asValue(loc, method);
3748
3749 auto inputs = apply.inputs();
3750 if (inputs.size() != 1) {
3751 throw(
3752 ErrorReport(apply)
3753 << "__getitem__ expected exactly 1 arguments, got "
3754 << inputs.size());
3755 }
3756 auto input =
3757 emitSugaredExpr(apply.inputs()[0], 1)->asValue(loc, method);
3758 if (input->type()->kind() == TypeKind::TupleType) {
3759 return std::make_shared<SimpleValue>(
3760 emitIndex(loc, self, createTupleUnpack(input)));
3761 }
3762 return std::make_shared<SimpleValue>(emitIndex(loc, self, {input}));
3763 }
3764 default:
3765 TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form);
3766 }
3767 }
3768
emitApplySpecialFormForListtorch::jit::to_ir3769 std::shared_ptr<SugaredValue> emitApplySpecialFormForList(
3770 Apply& apply,
3771 const TypePtr& type_hint = nullptr) {
3772 if (apply.inputs().empty()) {
3773 TypePtr type = type_hint ? type_hint : ListType::ofTensors();
3774 if (!type->cast<ListType>()) {
3775 throw(
3776 ErrorReport(apply.range())
3777 << "Expected list type annotation for list(), found "
3778 << type_hint->repr_str());
3779 }
3780 return std::make_shared<SimpleValue>(
3781 graph
3782 ->insertNode(graph->createList(
3783 type->expectRef<ListType>().getElementType(), {}))
3784 ->output());
3785 }
3786 // list(iter) desugars to [_elem for _elem in iter]
3787 checkApplyNumInputs(apply, 1);
3788 auto iter_input = emitSugaredExpr(apply.inputs()[0], 1);
3789
3790 // aten::list builtin op is registered for List and Str input
3791 // dispatch to the builtin op to avoid perf slowdown on existing uses
3792 if (auto simple = asSimple(iter_input)) {
3793 if (simple->type()->cast<ListType>() ||
3794 simple->type()->cast<StringType>()) {
3795 return std::make_shared<SimpleValue>(emitBuiltinCall(
3796 apply.range(), *method.graph(), aten::list, {simple}, {}));
3797 }
3798 }
3799 const std::string& iter_name = createTempName("$_iter");
3800 environment_stack->setSugaredVar(
3801 apply.range(),
3802 iter_name,
3803 iter_input,
3804 /*annotated_type=*/nullptr);
3805
3806 const std::string& elem_name = createTempName("$_elem");
3807 auto ident =
3808 Var::create(apply.range(), Ident::create(apply.range(), elem_name));
3809 auto iter =
3810 Var::create(apply.range(), Ident::create(apply.range(), iter_name));
3811 auto lc = ListComp::create(apply.range(), ident, ident, iter);
3812 return std::make_shared<SimpleValue>(emitListComprehension(lc, type_hint));
3813 }
3814
emitApplySpecialFormForDicttorch::jit::to_ir3815 std::shared_ptr<SugaredValue> emitApplySpecialFormForDict(
3816 Apply& apply,
3817 const TypePtr& type_hint = nullptr) {
3818 auto check_type_assignment_error = [&](const TypePtr& key_type,
3819 const TypePtr& value_type,
3820 const TypePtr& annotated_dict_type) {
3821 std::stringstream ss;
3822 std::stringstream err;
3823
3824 auto annotated_k_type =
3825 annotated_dict_type->expect<DictType>()->getKeyType();
3826 auto annotated_v_type =
3827 annotated_dict_type->expect<DictType>()->getValueType();
3828
3829 const auto is_key_subtype = key_type == annotated_k_type;
3830 const auto is_value_subtype =
3831 value_type->isSubtypeOfExt(annotated_v_type, &ss);
3832
3833 if (!is_key_subtype) {
3834 err << "Generated key type " << key_type->repr_str()
3835 << " did not match the annotated key type, which was "
3836 << annotated_k_type->repr_str() << "\n";
3837 }
3838
3839 if (!is_value_subtype) {
3840 err << "Generated value type " << value_type->repr_str()
3841 << " did not match the annotated value type, which was "
3842 << annotated_v_type->repr_str() << "\n"
3843 << ss.str();
3844 }
3845
3846 if (!is_key_subtype || !is_value_subtype) {
3847 throw(ErrorReport(apply) << err.str());
3848 }
3849 };
3850
3851 auto add_kwargs = [&](Value* dc_value) {
3852 NamedValue self = NamedValue(apply.range(), "self", dc_value);
3853 for (const auto& kwarg : apply.attributes()) {
3854 auto name = StringLiteral::create(kwarg.range(), kwarg.name().name());
3855 auto k = emitExpr(name);
3856 auto v = emitExpr(kwarg.value());
3857 NamedValue input_k = NamedValue(kwarg.range(), "", k);
3858 NamedValue input_v = NamedValue(kwarg.range(), "", v);
3859
3860 check_type_assignment_error(k->type(), v->type(), dc_value->type());
3861
3862 emitBuiltinCall(
3863 kwarg.range(),
3864 *graph,
3865 aten::_set_item,
3866 {self, input_k, input_v},
3867 {});
3868 }
3869 };
3870
3871 auto treat_as_empty_container = [&]() {
3872 // true if `dict()`
3873 if (apply.inputs().empty() && !apply.attributes().empty()) {
3874 return true;
3875 }
3876 // true if `dict({})`
3877 if (!apply.inputs().empty() &&
3878 apply.inputs()[0].kind() == TK_DICT_LITERAL) {
3879 auto dict_lit = DictLiteral(apply.inputs()[0]);
3880 return dict_lit.key_inputs().empty() && dict_lit.value_inputs().empty();
3881 }
3882 // true if `dict([])`
3883 if (!apply.inputs().empty() &&
3884 apply.inputs()[0].kind() == TK_LIST_LITERAL) {
3885 auto list_lit = ListLiteral(apply.inputs()[0]);
3886 return list_lit.inputs().empty();
3887 }
3888 return false;
3889 };
3890
3891 TypePtr annotated_union_type =
3892 type_hint && type_hint->isUnionType() ? type_hint : nullptr;
3893
3894 auto add_union_cast = [&](Value* result) {
3895 Node* n =
3896 graph->insertNode(graph->create(prim::unchecked_cast, {result}));
3897 n->output()->setType(std::move(annotated_union_type));
3898 result = n->output();
3899 };
3900
3901 TypePtr refined_type_hint = type_hint;
3902
3903 std::vector<TypePtr> all_candidates = {};
3904
3905 auto type_match = [&](const TypePtr& t) {
3906 return t->kind() == DictType::Kind;
3907 };
3908
3909 if (type_hint && type_hint->kind() != DictType::Kind) {
3910 refineAndSetUnionTypeHintOrPopulateCandidatesVector(
3911 type_hint,
3912 &refined_type_hint,
3913 &all_candidates,
3914 "Dict",
3915 apply,
3916 type_match,
3917 [] {},
3918 [] {},
3919 /*is_dict_constructor=*/true);
3920 }
3921
3922 if (!all_candidates.empty()) {
3923 throw(
3924 ErrorReport(apply)
3925 << "There are multiple candidate "
3926 << "Dict types in the Union type annotation `"
3927 << type_hint->repr_str()
3928 << "`, and full type inference is not yet supported for the "
3929 << "`dict()` constructor.");
3930 }
3931
3932 // If possible, just cast what we have to a Dict and add the
3933 // kwargs by hand. This is not only the simplest solution; it also
3934 // hits cases like `dict(dict([1, 2, 3]))` or `dict(x)` (where `x`
3935 // is some previously-defined variable)
3936 if (!apply.inputs().empty()) {
3937 // TODO(@ansley): Fix this! We have a weird situation where the
3938 // dict constructor may be handed an internal container literal
3939 // or comprehension, in which case we'd throw an error because
3940 // the lhs type wouldn't match the rhs type (the compiler wouldn't
3941 // be able to tell that this was part of a nested expression). We
3942 // used to get around this by simply not passing `type_hint`, but
3943 // 1) that's bad, and 2) we actually need `type_hint` for
3944 // inference now that Union has been introduced.
3945 std::shared_ptr<SugaredValue> iter_input;
3946 try {
3947 iter_input = emitSugaredExpr(apply.inputs()[0], 1, type_hint);
3948 } catch (const ErrorReport&) {
3949 iter_input = emitSugaredExpr(apply.inputs()[0], 1);
3950 }
3951 if (auto simple = asSimple(iter_input)) {
3952 if (simple->type()->cast<DictType>()) {
3953 auto dc_value = emitBuiltinCall(
3954 apply.range(), *method.graph(), aten::dict, {simple}, {});
3955 add_kwargs(dc_value);
3956 if (annotated_union_type) {
3957 add_union_cast(dc_value);
3958 }
3959 return std::make_shared<SimpleValue>(dc_value);
3960 }
3961 }
3962 }
3963
3964 // If we have a call with an empty container, or if we have a
3965 // call with kwargs only
3966 if (treat_as_empty_container()) {
3967 auto expr_list = List<Expr>::create(apply.range(), {});
3968 apply = Apply::create(
3969 apply.range(), apply.callee(), expr_list, apply.attributes());
3970 }
3971
3972 // If we have a completely empty call to dict()
3973 if (apply.inputs().empty() && apply.attributes().empty()) {
3974 if (!refined_type_hint) {
3975 refined_type_hint =
3976 DictType::create(StringType::get(), TensorType::get());
3977 } else if (!all_candidates.empty()) {
3978 throw(
3979 ErrorReport(apply.range())
3980 << "Cannot determine the type "
3981 << "of an empty dict given the Union annotation `"
3982 << type_hint->repr_str() << "`, which contains multiple "
3983 << "candidate Dict types ");
3984 }
3985
3986 TORCH_CHECK(
3987 refined_type_hint->kind() == DictType::Kind,
3988 "Expected a type annotation "
3989 "of Dict for dict constructor dict(), got ",
3990 type_hint->str());
3991
3992 return std::make_shared<SimpleValue>(
3993 graph
3994 ->insertNode(graph->createDict(
3995 refined_type_hint->expect<DictType>()->getKeyType(),
3996 refined_type_hint->expect<DictType>()->getValueType(),
3997 {},
3998 {}))
3999 ->output());
4000 }
4001
4002 // Special-case logic for if we have a dict comprehension
4003 if (!apply.inputs().empty() && apply.inputs()[0].kind() == TK_DICT_COMP) {
4004 auto dc = DictComp(apply.inputs()[0]);
4005 auto dc_value = emitDictComprehension(dc, refined_type_hint);
4006 add_kwargs(dc_value);
4007 return std::make_shared<SimpleValue>(dc_value);
4008 }
4009
4010 // We can't feasibly register all possible key x value
4011 // combinations of new prim ops for the case that we use the
4012 // constructor with a dict literal. It makes much more sense
4013 // to transform the dict literal into a list of tuples so that
4014 // we can use the existing constructors
4015 if (!apply.inputs().empty() &&
4016 apply.inputs()[0].kind() == TK_DICT_LITERAL) {
4017 auto dict_lit = DictLiteral(apply.inputs()[0]);
4018 std::vector<Expr> zipped;
4019 zipped.reserve(dict_lit.key_inputs().size());
4020 TORCH_INTERNAL_ASSERT(
4021 dict_lit.key_inputs().size() == dict_lit.value_inputs().size());
4022 for (auto key_it = dict_lit.key_inputs().begin(),
4023 val_it = dict_lit.value_inputs().begin();
4024 key_it != dict_lit.key_inputs().end();
4025 ++key_it, ++val_it) {
4026 auto tuple_inputs =
4027 List<Expr>::create(apply.range(), {*key_it, *val_it});
4028 auto tuple = TupleLiteral::create(apply.range(), tuple_inputs);
4029 zipped.push_back(tuple);
4030 }
4031 auto ll_values = List<Expr>::create(apply.range(), zipped);
4032 auto ll = ListLiteral::create(apply.range(), ll_values);
4033 auto expr_list = List<Expr>::create(apply.range(), {ll});
4034 // Change `apply` to a new Apply node holding a list of
4035 // tuples
4036 apply = Apply::create(
4037 apply.range(), apply.callee(), expr_list, apply.attributes());
4038 }
4039
4040 // If we have kwargs to include, we'll take a similar approach
4041 // to the above logic and standardize the Apply node
4042 if (!apply.attributes().empty() &&
4043 (apply.inputs().empty() ||
4044 apply.inputs()[0].kind() == TK_LIST_LITERAL)) {
4045 std::vector<Expr> exprs;
4046 // Gather all the existing tuples in the input iterable
4047 if (!apply.inputs().empty()) {
4048 auto tuple_list = ListLiteral(apply.inputs()[0]).inputs();
4049 for (const auto& tuple : tuple_list) {
4050 exprs.push_back(tuple);
4051 }
4052 }
4053 // Create tuples out of each kwarg and gather them as well
4054 for (const auto& attr : apply.attributes()) {
4055 auto k = StringLiteral::create(apply.range(), attr.name().name());
4056 auto v = attr.value();
4057 auto tuple_inputs = List<Expr>::create(apply.range(), {k, v});
4058 auto tuple = TupleLiteral::create(apply.range(), tuple_inputs);
4059 exprs.push_back(tuple);
4060 }
4061 auto expr_list = List<Expr>::create(apply.range(), {exprs});
4062 auto ll = ListLiteral::create(apply.range(), expr_list);
4063 auto new_inputs = List<Expr>::create(apply.range(), {ll});
4064 auto new_kwargs = List<Attribute>::create(apply.range(), {});
4065 apply =
4066 Apply::create(apply.range(), apply.callee(), new_inputs, new_kwargs);
4067 }
4068
4069 checkApplyNumInputs(apply, 1);
4070
4071 auto iter_input = emitSugaredExpr(apply.inputs()[0], 1);
4072
4073 const std::string& iter_name = createTempName("$_iter");
4074 const std::string& key_name = createTempName("$_key");
4075 const std::string& value_name = createTempName("$_value");
4076
4077 auto key =
4078 Var::create(apply.range(), Ident::create(apply.range(), key_name));
4079 auto value =
4080 Var::create(apply.range(), Ident::create(apply.range(), value_name));
4081 auto target = TupleLiteral::create(
4082 apply.range(), List<Expr>::create(apply.range(), {key, value}));
4083 auto iter =
4084 Var::create(apply.range(), Ident::create(apply.range(), iter_name));
4085
4086 environment_stack->setSugaredVar(
4087 apply.range(),
4088 iter_name,
4089 iter_input,
4090 /*annotated_type=*/nullptr);
4091
4092 auto dc = DictComp::create(apply.range(), key, value, target, iter);
4093 auto result = emitDictComprehension(dc, refined_type_hint);
4094 add_kwargs(result);
4095
4096 if (annotated_union_type) {
4097 add_union_cast(result);
4098 }
4099
4100 return std::make_shared<SimpleValue>(result);
4101 }
4102
emitExprtorch::jit::to_ir4103 Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
4104 // Push the source range of a call in case compiling this function
4105 // triggers an error
4106 ErrorReport::CallStack::update_pending_range(tree.range());
4107 Value* out_val =
4108 emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
4109 // AnyType is the only user-exposed type which we don't unify to from
4110 // its subtypes, so we add a cast for use cases like
4111 // x : Any = 1 if cond else "str"
4112 if (type_hint == AnyType::get() && out_val->type() != AnyType::get()) {
4113 out_val = graph->insertUncheckedCast(out_val, type_hint);
4114 }
4115 return out_val;
4116 }
4117
reverseComparisiontorch::jit::to_ir4118 NodeKind reverseComparision(NodeKind kind) {
4119 if (kind == aten::lt) {
4120 return aten::gt;
4121 } else if (kind == aten::le) {
4122 return aten::ge;
4123 } else if (kind == aten::gt) {
4124 return aten::lt;
4125 } else if (kind == aten::ge) {
4126 return aten::le;
4127 }
4128 throw std::runtime_error(
4129 "reverseComparision: unsupported NodeKind. File a bug");
4130 }
4131
4132 // any expression that can produce a SugaredValue is handled here
4133 // expressions that only return a single Value* are handled in emitSimpleExpr
4134 // type_hint is set if there is a type that this value is expected to be
4135 // e.g. a : List[int] = []
4136 // or a = torch.jit.annotate(List[int], [])
4137 // the caller is responsible for checking that the result matches type_hint
4138 // emitSugaredExpr is free to ignore it.
emitSugaredExprtorch::jit::to_ir4139 std::shared_ptr<SugaredValue> emitSugaredExpr(
4140 const Expr& tree,
4141 size_t n_binders,
4142 const TypePtr& type_hint = nullptr) {
4143 switch (tree.kind()) {
4144 case TK_VAR: {
4145 return environment_stack->getSugaredVar(Var(tree).name());
4146 }
4147 case '.': {
4148 auto select = Select(tree);
4149 auto sv = emitSugaredExpr(select.value(), 1);
4150 return sv->attr(select.range(), method, select.selector().name());
4151 }
4152 case TK_APPLY: {
4153 auto apply = Apply(tree);
4154 return emitApplyExpr(apply, n_binders, type_hint);
4155 } break;
4156 case TK_SUBSCRIPT: {
4157 return emitSubscript(Subscript(tree), type_hint);
4158 } break;
4159 default:
4160 return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
4161 }
4162 }
4163
emitUnaryOptorch::jit::to_ir4164 Value* emitUnaryOp(
4165 const TreeRef& tree,
4166 const std::string& magicMethod,
4167 const c10::Symbol& opSymbol) {
4168 const auto& inputs = tree->trees();
4169 auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
4170 auto val =
4171 asSimple(makeMagic(
4172 magicMethod,
4173 std::make_shared<BuiltinFunction>(opSymbol, std::nullopt))
4174 ->call(tree->range(), method, named_values, {}, 0));
4175
4176 // if we emitted the unary op and not some other overloaded function,
4177 // then try to constantfold
4178 if (val->node()->kind() != opSymbol) {
4179 return val;
4180 }
4181
4182 auto maybe_out_stack = runNodeIfInputsAreConstant(val->node());
4183 if (!maybe_out_stack) {
4184 return val;
4185 }
4186 TORCH_INTERNAL_ASSERT(maybe_out_stack->size() == 1);
4187 return graph->insertConstant(maybe_out_stack->at(0), tree->range());
4188 }
4189
  /**
   * Emit a fork expression, of the form:
   *   torch.jit.fork(forked, *args, **kwargs)
   *
   * The callee is wrapped in a closure whose single output feeds a
   * prim::forkClosure node; the fork's output value is typed
   * Future[T], where T is the closure's output type.
   */
  std::shared_ptr<SugaredValue> emitForkExpr(
      SourceRange loc,
      const std::shared_ptr<SugaredValue>& forked,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs) {
    auto g = method.graph();
    TypePtr out_type;

    auto fork_node = g->insertNode(method.graph()->create(prim::forkClosure, 1))
                         ->setSourceRange(loc);

    // We create a fork by emitting a closure and setting the closure output
    // into the fork input. If a closure doesn't already exist, we create one.
    {
      // Scope the insert point so the closure is emitted before fork_node.
      WithInsertPoint insert(fork_node);
      if (ClosureValue* sv = dynamic_cast<ClosureValue*>(forked.get())) {
        // Already a closure: reuse it directly as the fork input. Its
        // single block output determines the Future's element type.
        Value* closure_output = sv->asValue(loc, method);
        Block* closure_block = closure_output->node()->blocks().at(0);
        TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1);
        out_type = closure_block->outputs().at(0)->type();
        fork_node->addInput(closure_output);
      } else {
        // Otherwise, synthesize a closure that performs the call and
        // registers its result as the closure's output.
        auto emit_closure_body = [&](Block* closure_block) {
          auto fn_sugared_output = forked->call(loc, method, args, kwargs, 1);
          auto fn_simple_output = fn_sugared_output->asValue(loc, method);
          closure_block->registerOutput(fn_simple_output);
          out_type = fn_simple_output->type();
        };
        auto closure_value = emitClosure(emit_closure_body);
        fork_node->addInput(closure_value->asValue(loc, method));
      }
    }
    Value* node_output =
        fork_node->output()->setType(FutureType::create(out_type));
    return std::make_shared<SimpleValue>(node_output);
  }
4230
emitAwaitableExprtorch::jit::to_ir4231 std::shared_ptr<SugaredValue> emitAwaitableExpr(
4232 SourceRange loc,
4233 const std::shared_ptr<SugaredValue>& awaited,
4234 at::ArrayRef<NamedValue> args,
4235 at::ArrayRef<NamedValue> kwargs) {
4236 auto g = method.graph();
4237 TypePtr out_type{};
4238
4239 auto await_node =
4240 g->insertNode(method.graph()->create(prim::awaitableClosure, 1))
4241 ->setSourceRange(loc);
4242
4243 {
4244 WithInsertPoint insert(await_node);
4245 if (auto sv = dynamic_cast<ClosureValue*>(awaited.get())) {
4246 Value* closure_output = sv->asValue(loc, method);
4247 Block* closure_block = closure_output->node()->blocks().at(0);
4248 TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1);
4249 out_type = closure_block->outputs().at(0)->type();
4250 await_node->addInput(closure_output);
4251 } else {
4252 auto emit_closure_body = [&](Block* closure_block) {
4253 auto fn_sugared_output = awaited->call(loc, method, args, kwargs, 1);
4254 auto fn_simple_output = fn_sugared_output->asValue(loc, method);
4255 closure_block->registerOutput(fn_simple_output);
4256 out_type = fn_simple_output->type();
4257 };
4258 auto closure_value = emitClosure(emit_closure_body);
4259 await_node->addInput(closure_value->asValue(loc, method));
4260 }
4261 }
4262 Value* node_output =
4263 await_node->output()->setType(AwaitType::create(out_type));
4264 return std::make_shared<SimpleValue>(node_output);
4265 }
4266
emitRpcExprtorch::jit::to_ir4267 std::shared_ptr<SugaredValue> emitRpcExpr(const Apply& apply, Symbol rpc_op) {
4268 // TODO: This is a temporary apporoach to enable calling user fucntion
4269 // through RPC in TorchScript,
4270 // Ideally, function value in JIT IR is first-class citizen and
4271 // The RPC C++ entry API can take c10::Function directly.
4272 size_t rpcMinInputs = 2;
4273 size_t rpcMaxInputs = 5;
4274 std::string op_name = rpc_op.toUnqualString();
4275 if (apply.inputs().size() < rpcMinInputs ||
4276 apply.inputs().size() > rpcMaxInputs) {
4277 throw(
4278 ErrorReport(apply)
4279 << "Possible forms of call to " << op_name << "(..) are\n"
4280 << op_name
4281 << "(dst_worker_name, user_callable, args, kwargs, timeout)\n"
4282 << op_name << "(dst_worker_name, user_callable, args, kwargs)\n"
4283 << op_name << "(dst_worker_name, user_callable, args)\n"
4284 << op_name << "(dst_worker_name, user_callable)\n"
4285 << "Now the number of arguments is " << apply.inputs().size());
4286 }
4287 if (!apply.attributes().empty()) {
4288 throw(
4289 ErrorReport(apply)
4290 << op_name << "(dst_worker_name, user_callable, args, kwargs)"
4291 << "does not support kwargs yet");
4292 }
4293 // TODO: Make rpc_op(..) support taking kwargs,
4294 // like rpc_async(to="worker1", func=my_func, args=(), kwargs={})
4295
4296 auto& input_trees = apply.inputs().tree()->trees();
4297 Value* dst_worker_name_value = emitExpr(Expr(input_trees[0]));
4298 std::shared_ptr<SugaredValue> user_callable_sugared_value =
4299 emitSugaredExpr(Expr(input_trees[1]), 1);
4300 TORCH_CHECK(
4301 user_callable_sugared_value->kind() == "function",
4302 "user_callable should be a FunctionValue, it's now a ",
4303 user_callable_sugared_value->kind())
4304 // NB: This should be done using `std::dynamic_pointer_cast`
4305 // and assert `user_callable_function_value != nullptr`. But somehow on
4306 // macos std::dynamic_pointer_cast always returns
4307 // `user_callable_function_value` as a `nullptr`, even if
4308 // `user_callable_sugared_value->kind() == "function"`.
4309 std::shared_ptr<FunctionValue> user_callable_function_value =
4310 std::static_pointer_cast<FunctionValue>(user_callable_sugared_value);
4311 // If `kwargs` is an empty dict, users are allowed to not pass `kwargs`.
4312 // If `args` and `kwargs` are an empty tuple and an empty dict,
4313 // respectively, users are allowed to not pass `args` and `kwargs`.
4314
4315 TreeList args_kwargs_timeout_trees(
4316 input_trees.begin() + 2, input_trees.end());
4317
4318 // Get user callable.
4319 const auto& callablePtrs = user_callable_function_value->callees();
4320 TORCH_INTERNAL_ASSERT(
4321 callablePtrs.size() == 1,
4322 "User-provided callable size should be 1. Now it's",
4323 callablePtrs.size())
4324 Function* callablePtr = callablePtrs.at(0);
4325
4326 const auto& functionSchema = callablePtr->getSchema();
4327 const SourceRange& loc = apply.range();
4328 auto graphPtr = method.graph();
4329
4330 // Match FunctionSchema.
4331 std::vector<NamedValue> args;
4332 std::vector<NamedValue> kwargs;
4333 // Get args and kwargs as `NamedValue`s.
4334 // Similar to getNamedValues(..) and emitAttributes(..).
4335 if (!args_kwargs_timeout_trees.empty()) {
4336 // Unroll args from a Var that is known to be a Tuple.
4337 auto& args_tree = args_kwargs_timeout_trees[0];
4338 auto entry_sugared_values = emitSugaredExpr(Expr(args_tree), 1)
4339 ->asTuple(args_tree->range(), method);
4340 args.reserve(entry_sugared_values.size());
4341 for (const auto& entrie_sugared_value : entry_sugared_values) {
4342 args.emplace_back(
4343 args_tree->range(),
4344 entrie_sugared_value->asValue(args_tree->range(), method));
4345 }
4346 // NB: Can't do schema check on kwargs, given the RPC API is
4347 // rpc_op(to, user_callable, args, kwargs),
4348 // users can construct kwargs = {"first" + "_arg" : 1}.
4349 // Notice the key is determined at run time.
4350 // We can do it at compile time, unless one day the RPC API is
4351 // rpc_op(to, user_callable, arg_0, arg_1, kwarg_0="foo",
4352 // kwarg_1="bar")
4353 }
4354 matchSchema(functionSchema, loc, *graphPtr, args, kwargs);
4355
4356 // Graph insert the QualifiedName as an constant input IR Value.
4357 const auto& qualname = callablePtr->qualname();
4358 IValue userCallableQualNameIValue(qualname.qualifiedName());
4359 Value* userCallableQualNameValue =
4360 graphPtr->insertConstant(userCallableQualNameIValue, loc);
4361
4362 // Graph insert the corresponding RPC node to the graph.
4363 Node* rpc_node =
4364 graphPtr->insertNode(graphPtr->create(rpc_op, 1))->setSourceRange(loc);
4365 {
4366 WithInsertPoint insert(rpc_node);
4367 rpc_node->addInput(dst_worker_name_value);
4368 rpc_node->addInput(userCallableQualNameValue);
4369
4370 for (const auto& tree : args_kwargs_timeout_trees) {
4371 rpc_node->addInput(emitExpr(Expr(tree)));
4372 }
4373 }
4374 Value* rpc_node_output = rpc_node->output();
4375
4376 // Set output type from FunctionSchema and corresponding rpc_op.
4377 const std::vector<Argument>& returns = functionSchema.returns();
4378 TORCH_INTERNAL_ASSERT(returns.size() == 1);
4379 TypePtr output_type = nullptr;
4380 if (rpc_op == prim::rpc_async) {
4381 // rpc_async returns FutureType of the functionSchema's return type
4382 output_type = FutureType::create(returns[0].type());
4383 } else if (rpc_op == prim::rpc_sync) {
4384 // rpc_sync returns the functionSchema's return type
4385 output_type = returns[0].type();
4386 } else if (rpc_op == prim::rpc_remote) {
4387 // rpc_remote returns RRefType of the functionSchema's return type
4388 output_type = RRefType::create(returns[0].type());
4389 } else {
4390 throw(
4391 ErrorReport(apply)
4392 << rpc_op.toDisplayString() << " is not supported in TorchScript!'");
4393 }
4394 rpc_node_output->setType(output_type);
4395 return std::make_shared<SimpleValue>(rpc_node_output);
4396 }
4397
emitBinaryOptorch::jit::to_ir4398 Value* emitBinaryOp(const TreeRef& tree) {
4399 const auto& inputs = tree->trees();
4400 auto kind = getNodeKind(tree->kind(), inputs.size());
4401 auto overload = getOperatorOverload(tree->kind(), inputs.size());
4402 auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
4403 if (tree->kind() == TK_IN) {
4404 // For `in` the arguments are in reverse order (the object being
4405 // checked is second)
4406 std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
4407 }
4408
4409 // if this is adding two tuples, we deal with it here.
4410 // the reason is we can't specify the length of tuples
4411 // when registering custom aten::add.
4412 if (named_values[0].type()->kind() == TupleType::Kind &&
4413 named_values[1].type()->kind() == TupleType::Kind &&
4414 kind == aten::add) {
4415 auto first_tuple = createTupleUnpack(named_values[0].value(*graph)).vec();
4416 auto second_tuple =
4417 createTupleUnpack(named_values[1].value(*graph)).vec();
4418 first_tuple.insert(
4419 first_tuple.end(), second_tuple.begin(), second_tuple.end());
4420 return graph->insertNode(graph->createTuple(first_tuple))->output();
4421 }
4422
4423 return asSimple(
4424 makeMagic(
4425 overload, std::make_shared<BuiltinFunction>(kind, std::nullopt))
4426 ->call(tree->range(), method, named_values, {}, 0));
4427 }
4428
  // Emits a list literal `[e1, e2, ...]` as a prim::ListConstruct (or a
  // placeholder prim::EmptyListLiteral for an unannotated, unassigned `[]`).
  // `type_hint` is the annotated type, if any; a Union/Optional hint is
  // refined to the List type matching the elements, and an unchecked_cast
  // back to the Union annotation is emitted around the result.
  Value* emitListLiteral(const ListLiteral& ll, const TypePtr& type_hint) {
    auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);

    // Empty List Literals that are not assigned to variables
    // may match to any list type in schema matching,
    // but still default to List[Tensor] if assigned to a variable
    // or returned from a function
    // Restricting empty list matching to temporary values
    // avoids difficult to handle cases such as
    //   a = []
    //   b = a
    //   if cond:
    //     b.append(2)
    //   else:
    //     a.append("hi")
    // This is also the same behavior that C++ allows with {}
    // (cannot assign to a variable typed as auto)
    // These nodes will be removed in a later pass after initial compilation
    if (values.empty() && type_hint == nullptr) {
      auto node = graph->insertNode(graph->create(prim::EmptyListLiteral));
      node->output()->setType(ListType::ofTensors());
      return node->output();
    }

    // Determine the element type of the list. If we have a type hint
    // of `List[T]`, use `T`. If the list is non-empty, find the
    // greatest common supertype of all the list elements (defaulting to
    // `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]`
    TypePtr inferred_elem_type = TensorType::get();

    TypePtr refined_type_hint = type_hint;

    // If `type_hint` is a Union/Optional, we're going to change it to
    // be the type of the rhs List, so we need to store the original
    // UnionType for later. `nullptr` means that we don't need to emit
    // an `unchecked_cast` node (either because we don't have a type
    // hint or because the type hint wasn't a Union)
    TypePtr annotated_union_type =
        refined_type_hint && refined_type_hint->isUnionType()
        ? refined_type_hint
        : nullptr;

    // This is used in the case that we have a Union annotation that
    // contains multiple Lists
    std::vector<TypePtr> all_candidates = {};

    if (refined_type_hint) {
      // Invoked when the (possibly refined) hint is itself a List type;
      // takes its element type as the inferred element type.
      auto do_if_type_match = [&]() {
        auto list_type_hint = refined_type_hint->cast<ListType>();
        inferred_elem_type = list_type_hint->getElementType();
      };

      auto type_match = [&](const TypePtr& t) {
        return t->isSubtypeOf(AnyListType::get());
      };

      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "List",
          ll,
          type_match,
          do_if_type_match,
          do_if_type_match);

      // An empty literal cannot disambiguate between multiple List
      // candidates from a Union annotation.
      if (!all_candidates.empty() && values.empty()) {
        throw(
            ErrorReport(ll)
            << "Cannot assign an empty list to a "
            << "variable annotated to be type " << refined_type_hint->repr_str()
            << " because there are multiple possible List "
            << "type candidates in the Union annotation");
      }
    }

    // Non-empty literal: unify the element types and check them against
    // the (refined) hint, or use them to pick among Union candidates.
    if (!values.empty()) {
      auto types = fmap(values, [](const Value* v) { return v->type(); });

      std::stringstream nowhere; // never used

      // We don't want to use `elem_type` as the final argument to
      // `unifyTypeList` because there's a chance that `elem_type` is
      // the Tensor default
      const auto elem_type_hint =
          refined_type_hint && refined_type_hint->kind() == ListType::Kind
          ? refined_type_hint->cast<ListType>()->getElementType()
          : nullptr;

      // NOTE(review): dereferenced without a has_value() check below —
      // presumably unifyTypeList always yields a type when
      // default_to_union=true; confirm against its implementation.
      std::optional<TypePtr> unified_elem_type = unifyTypeList(
          types, nowhere, /*default_to_union=*/true, elem_type_hint);

      if (!refined_type_hint &&
          (*unified_elem_type)->kind() == UnionType::Kind) {
        TORCH_WARN(
            "List consists of heterogeneous types, which means",
            " that it has been typed as containing ",
            (*unified_elem_type)->repr_str(),
            ". To use any of the "
            "values in this List, it will be necessary to add an "
            "`assert isinstance` statement before first use to trigger "
            "type refinement.\n",
            ll.range().str());
      }

      if (all_candidates.empty() && refined_type_hint &&
          !(*unified_elem_type)->isSubtypeOf(*inferred_elem_type)) {
        throw(
            ErrorReport(ll)
            << "List type annotation `" << refined_type_hint->repr_str()
            << "` did not match the types of the given list elements,"
            << " which were unified to " << (*unified_elem_type)->repr_str());
      }

      if (!all_candidates.empty()) {
        refineAndSetListTypeHintFromCandidatesVector(
            all_candidates,
            type_hint,
            &refined_type_hint,
            *unified_elem_type,
            ll);
        inferred_elem_type =
            refined_type_hint->expect<ListType>()->getElementType();
      }

      // We only want to set `elem_type` if we don't have a type hint
      // to allow for the case that `*unified` is a subtype of
      // `type_hint`
      if (!refined_type_hint) {
        inferred_elem_type = *unified_elem_type;
      }
    }

    Node* result =
        graph->insertNode(graph->createList(inferred_elem_type, values));
    // If the original annotation was a Union, cast the result back to it
    // so the output Value carries the annotated type.
    if (annotated_union_type) {
      Node* n = graph->insertNode(
          graph->create(prim::unchecked_cast, {result->output()}));
      n->output()->setType(std::move(annotated_union_type));
      result = n;
    }

    return result->output();
  }
4573
emitDictLiteraltorch::jit::to_ir4574 Value* emitDictLiteral(DictLiteral dl, const TypePtr& type_hint) {
4575 auto key_trees = dl.key_inputs().tree()->trees();
4576 auto value_trees = dl.value_inputs().tree()->trees();
4577
4578 AT_ASSERT(key_trees.size() == value_trees.size());
4579
4580 std::vector<Value*> keys, values;
4581 TypePtr rhs_value_type;
4582
4583 for (const auto i : c10::irange(key_trees.size())) {
4584 keys.push_back(emitExpr(Expr(key_trees[i])));
4585 values.push_back(emitExpr(Expr(value_trees[i])));
4586
4587 if (i == 0) {
4588 rhs_value_type = values[i]->type();
4589 } else {
4590 if (keys[i - 1]->type()->kind() != keys[i]->type()->kind()) {
4591 throw(
4592 ErrorReport(key_trees[i])
4593 << "Dict keys must contain "
4594 << "only a single type. Expected: "
4595 << keys[i - 1]->type()->repr_str() << " but found "
4596 << keys[i]->type()->repr_str() << " instead");
4597 }
4598 rhs_value_type = *(unifyTypes(
4599 rhs_value_type, values[i]->type(), /*default_to_union=*/true));
4600 }
4601 }
4602
4603 TypePtr refined_type_hint = type_hint;
4604
4605 TypePtr annotated_union_type =
4606 type_hint && type_hint->isUnionType() ? type_hint : nullptr;
4607
4608 std::vector<TypePtr> all_candidates = {};
4609
4610 auto default_refined_type_hint_setter = [&]() {
4611 if (keys.empty()) {
4612 refined_type_hint =
4613 DictType::create(StringType::get(), TensorType::get());
4614 } else {
4615 refined_type_hint =
4616 DictType::create(keys.at(0)->type(), rhs_value_type);
4617 if (rhs_value_type->kind() == UnionType::Kind) {
4618 TORCH_WARN(
4619 "Dict values consist of heterogeneous types, which means",
4620 " that the dict has been typed as containing ",
4621 refined_type_hint->repr_str(),
4622 ". To use any of the values in this Dict, it will be "
4623 "necessary to add an `assert isinstance` statement before "
4624 "first use to trigger type refinement.\n",
4625 dl.range().str());
4626 }
4627 }
4628 };
4629
4630 if (type_hint) {
4631 auto type_match = [&](const TypePtr& t) {
4632 return t->kind() == DictType::Kind;
4633 };
4634
4635 refineAndSetUnionTypeHintOrPopulateCandidatesVector(
4636 type_hint,
4637 &refined_type_hint,
4638 &all_candidates,
4639 "Dict",
4640 dl,
4641 type_match,
4642 [] {},
4643 default_refined_type_hint_setter);
4644
4645 if (!all_candidates.empty() && values.empty()) {
4646 throw(
4647 ErrorReport(dl)
4648 << "Cannot assign an empty dict to a "
4649 << "variable annotated to be type " << type_hint->repr_str()
4650 << " because there are multiple possible Dict "
4651 << "type candidates in the Union annotation");
4652 }
4653 } else {
4654 default_refined_type_hint_setter();
4655 }
4656
4657 // We must have either a) specific key/value types already, or b) a
4658 // list of possible candidates
4659 TORCH_INTERNAL_ASSERT(!all_candidates.empty() || refined_type_hint);
4660
4661 if (!values.empty()) {
4662 if (!all_candidates.empty()) {
4663 refineAndSetDictTypeHintFromCandidatesVector(
4664 all_candidates,
4665 type_hint,
4666 &refined_type_hint,
4667 keys[0]->type(),
4668 rhs_value_type,
4669 dl);
4670 }
4671
4672 if (refined_type_hint->expect<DictType>()->getKeyType() !=
4673 keys.at(0)->type()) {
4674 throw(
4675 ErrorReport(dl)
4676 << "Type annotation was inferred to be "
4677 << refined_type_hint->repr_str()
4678 << "but the type of keys given by the dict literal is "
4679 << keys.at(0)->type()->repr_str());
4680 }
4681
4682 if (!rhs_value_type->isSubtypeOf(
4683 refined_type_hint->expect<DictType>()->getValueType())) {
4684 throw(
4685 ErrorReport(dl)
4686 << "Type annotation was inferred to be `"
4687 << refined_type_hint->repr_str()
4688 << "`, but the type of values given by the dict literal is "
4689 << rhs_value_type->repr_str());
4690 }
4691 }
4692
4693 Node* result = graph->insertNode(graph->createDict(
4694 refined_type_hint->expect<DictType>()->getKeyType(),
4695 refined_type_hint->expect<DictType>()->getValueType(),
4696 keys,
4697 values));
4698 if (annotated_union_type) {
4699 Node* n = graph->insertNode(
4700 graph->create(prim::unchecked_cast, {result->output()}));
4701 n->output()->setType(std::move(annotated_union_type));
4702 result = n;
4703 }
4704
4705 return result->output();
4706 }
4707
  // Emits an expression tree that produces a single Value, dispatching on
  // the tree kind. `type_hint` (optional) is forwarded to the literal,
  // comprehension, and ternary emitters, which can use it to refine the
  // result type.
  Value* emitSimpleExpr(
      const TreeRef& tree,
      const TypePtr& type_hint = nullptr) {
    switch (tree->kind()) {
      // Floor division and matrix multiply map directly to builtin calls.
      case TK_FLOOR_DIV:
      case '@': {
        const auto& inputs = tree->trees();
        auto kind = getNodeKind(tree->kind(), inputs.size());
        auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
        return emitBuiltinCall(
            tree->range(), *method.graph(), kind, named_values, {});
      }
      case '%': {
        // `%` on a string lhs is printf-style formatting
        // (aten::percentFormat); otherwise it is ordinary binary modulo.
        auto lhs = emitSugaredExpr(Expr(tree->tree(0)), 0)
                       ->asValue(tree->tree(0)->range(), method);
        auto const& lhs_type = lhs->type();
        if (lhs_type == StringType::get()) {
          auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
          auto node = graph->create(aten::percentFormat, values, 1)
                          ->setSourceRange(tree->range());
          Value* output = graph->insertNode(node)->output();
          output->setType(StringType::get());
          return output;
        } else {
          return emitBinaryOp(tree);
        }
      }
      case TK_IN:
      case TK_POW:
      case TK_NE:
      case TK_EQ:
      case '<':
      case '>':
      case TK_LE:
      case TK_GE:
      case '*':
      case '/':
      case '+':
      case '-':
      case '&':
      case '|':
      case '^':
      case TK_LSHIFT:
      case TK_RSHIFT:
        return emitBinaryOp(tree);
      // Boolean / identity operators go through the conditional-expression
      // path and are reduced to their plain Value here.
      case TK_IS:
      case TK_ISNOT:
      case TK_AND:
      case TK_OR:
      case TK_NOT: {
        return emitCondExpr(Expr(tree)).value();
      }
      case TK_UNARY_MINUS: {
        return emitUnaryOp(tree, "__neg__", aten::neg);
      }
      case '~': {
        return emitUnaryOp(tree, "__invert__", aten::bitwise_not);
      }
      case TK_STARRED: {
        // Starred expansion is only meaningful in unpacking contexts,
        // which are consumed before reaching this generic emitter.
        throw(
            ErrorReport(tree)
            << "Unexpected starred expansion. File a bug report");
      }
      case TK_CONST: {
        return emitConst(Const(tree));
      } break;
      case TK_TRUE: {
        return graph->insertConstant(true, tree->range());
      } break;
      case TK_FALSE: {
        return graph->insertConstant(false, tree->range());
      } break;
      case TK_NONE: {
        // `None` is an IValue with no contents.
        return graph->insertConstant(IValue(), tree->range());
      } break;
      case TK_IF_EXPR: {
        return emitTernaryIf(TernaryIf(tree), type_hint);
      } break;
      case TK_STRINGLITERAL: {
        return emitStringLiteral(StringLiteral(tree));
      } break;
      case TK_LIST_LITERAL: {
        auto ll = ListLiteral(tree);
        return emitListLiteral(ll, type_hint);
      } break;
      case TK_TUPLE_LITERAL: {
        auto ll = TupleLiteral(tree);
        auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
        return graph->insertNode(graph->createTuple(values))->output();
      } break;
      case TK_DICT_LITERAL: {
        auto dc = DictLiteral(tree);
        return emitDictLiteral(dc, type_hint);
      } break;
      case TK_LIST_COMP: {
        auto lc = ListComp(tree);
        return emitListComprehension(lc, type_hint);
      } break;
      case TK_DICT_COMP: {
        auto dc = DictComp(tree);
        return emitDictComprehension(dc, type_hint);
      } break;
      default:
        throw(ErrorReport(tree) << "Cannot emit expr for: " << tree);
    }
  }
4814
emitConsttorch::jit::to_ir4815 Value* emitConst(const Const& c) {
4816 if (c.isFloatingPoint())
4817 return materializeConstant(
4818 c.asFloatingPoint(), *graph, c.range(), fp_constants);
4819 else if (c.isComplex())
4820 return materializeConstant(
4821 c.asComplex(), *graph, c.range(), complex_constants);
4822 else
4823 return materializeConstant(
4824 c.asIntegral(), *graph, c.range(), integral_constants);
4825 }
4826
emitStringLiteraltorch::jit::to_ir4827 Value* emitStringLiteral(const StringLiteral& c) {
4828 return insertConstant(*graph, c.text(), c.range());
4829 }
4830
4831 // Desugars select indexing: tensor[i] -> tensor.select(dim, i)
emitSelecttorch::jit::to_ir4832 Value* emitSelect(
4833 const SourceRange& loc,
4834 Value* input,
4835 Value* dim,
4836 Value* index) {
4837 return emitBuiltinCall(loc, *graph, aten::select, {input, dim, index}, {});
4838 }
4839
emitSliceOptorch::jit::to_ir4840 Value* emitSliceOp(
4841 const SourceRange& loc,
4842 Value* sliceable,
4843 Value* dim,
4844 Value* start,
4845 Value* end,
4846 Value* step) {
4847 std::vector<NamedValue> args;
4848 args.reserve(5);
4849 args.emplace_back(loc, "self", sliceable);
4850
4851 // XXX: If list slicing becomes more complicated or stops using
4852 // aten::slice, we should separate it from this function.
4853 if (dim) {
4854 AT_ASSERT(sliceable->type()->isSubtypeOf(*TensorType::get()));
4855
4856 args.emplace_back(dim);
4857 } else {
4858 AT_ASSERT(!sliceable->type()->isSubtypeOf(*TensorType::get()));
4859 }
4860
4861 if (sliceable->type()->cast<TupleType>()) {
4862 std::vector<std::optional<NamedValue>> tuple_args;
4863 // since we are only dealing with tuple slicing, we try to keep
4864 // tuple args separate for now
4865 tuple_args.reserve(3);
4866
4867 start ? tuple_args.emplace_back(start)
4868 : tuple_args.emplace_back(std::nullopt);
4869 end ? tuple_args.emplace_back(end)
4870 : tuple_args.emplace_back(std::nullopt);
4871 step ? tuple_args.emplace_back(step)
4872 : tuple_args.emplace_back(std::nullopt);
4873
4874 return emitTupleSlice(loc, args[0], tuple_args);
4875 }
4876
4877 // handling cases like x[0:2]. x[0:2:] is already handled from python
4878 if (!step) {
4879 step = graph->insertConstant(1, loc);
4880 }
4881
4882 args.emplace_back(loc, "start", start);
4883 args.emplace_back(loc, "end", end);
4884 args.emplace_back(loc, "step", step);
4885 return emitBuiltinCall(loc, *graph, aten::slice, args, {});
4886 }
4887
4888 // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
4889 // 1)
emitSlicetorch::jit::to_ir4890 Value* emitSlice(
4891 const SourceRange& loc,
4892 Value* input,
4893 Value* dim, // Only used for tensor slicing
4894 const SliceExpr& slice) {
4895 Value* start = nullptr;
4896 Value* end = nullptr;
4897 Value* step = nullptr;
4898 if (slice.start().present()) {
4899 start = emitExpr(Expr(slice.start().get()));
4900 }
4901 if (slice.end().present()) {
4902 end = emitExpr(Expr(slice.end().get()));
4903 }
4904 if (slice.step().present()) {
4905 step = emitExpr(Expr(slice.step().get()));
4906 }
4907 return emitSliceOp(loc, input, dim, start, end, step);
4908 }
4909
emitUnsqueezetorch::jit::to_ir4910 Value* emitUnsqueeze(const SourceRange& loc, Value* input, Value* dim_val) {
4911 return emitBuiltinCall(loc, *graph, aten::unsqueeze, {input, dim_val}, {});
4912 }
4913
emitIndextorch::jit::to_ir4914 Value* emitIndex(
4915 const SourceRange& loc,
4916 Value* input,
4917 at::ArrayRef<Value*> indices) {
4918 // NB: the index of aten::index should be a type of List[Optional[Tensor]],
4919 // this is to support the case like t[:, :, 1] where : here indicates a
4920 // None/undefined tensor(optional tensor)
4921 auto* index =
4922 graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
4923 ->output();
4924 return emitBuiltinCall(loc, *graph, aten::index, {input, index}, {});
4925 }
4926
4927 // Emits multidimensional slicing with int and slice indices.
4928 // Returns:
4929 // - Value*: the input after it has been indexed by int and slice indices.
4930 // - vector<Value*>: A list of tensor Value* indices that have not been
4931 // applied yet.
4932 // Should be NULL at indices where sliceable (post-slicing) isn't indexed by
4933 // a tensor.
  std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    // Overall, to handle indexing (other than Tensors), we need to handle a
    // couple different things. For example, for x[1:3, None, 4], each of these
    // different index types (slice, None, and integer) result in different
    // number of dimensions. Slicing doesn't change the number of dimensions,
    // None adds a dimension, and integer removes a dimension. As these indexing
    // operations are applied left to right, the actual index that it's being
    // applied to depends on the previous operations. Ellipses indexing throws
    // another wrinkle. Ellipses selects any remaining unspecified dimensions.
    // Thus, for indexes following an ellipses, the actual index an indexing
    // operation is being applied to depends on the operations to the right.
    // Thus, we do two passes, one from left to right up until the ellipses, and
    // one from right to left.

    // Tensor subscripts that still need applying via aten::index,
    // positioned at the dimension they index (nullptr-padded; see the
    // contract described above emitMultidimSlicing).
    std::vector<Value*> tensor_indices;

    auto insert_value_for_dim = [&](int64_t dim) {
      return graph->insertConstant(dim, loc);
    };
    // dims[i]: the dimension subscript i applies to. Subscripts after an
    // ellipsis get negative dims, counted from the right.
    std::vector<int64_t> dims(subscript_exprs.size());
    // exprs[i]: the emitted index Value for subscript i; stays nullopt
    // for slice expressions / slice objects, which are re-emitted in the
    // application loop below.
    std::vector<std::optional<Value*>> exprs(
        subscript_exprs.size(), std::nullopt);

    // Records dims[expr_idx], emits the subscript if it is a plain index,
    // and returns the dimension the NEXT subscript applies to (moving
    // left when walking in reverse).
    auto handle_indexing = [&](const Expr& subscript_expr,
                               size_t expr_idx,
                               int64_t dim,
                               bool is_reverse = false) {
      dims[expr_idx] = dim;

      // Slice expression case, does not represent a single index.
      if (subscript_expr.kind() == TK_SLICE_EXPR) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim + 1;
        }
      }

      // Slice object case, does not represent a single index.
      auto subscript_sv = emitSugaredExpr(subscript_expr, 1);
      if (dynamic_cast<SliceValue*>(subscript_sv.get())) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim + 1;
        }
      }

      TypePtr type_hint;
      if (subscript_expr.kind() == TK_NONE) {
        type_hint = NoneType::get();
      }
      auto index = emitExpr(subscript_expr, type_hint);

      // Accept list as subscript but convert it to a Tensor
      // since it's equivalent to indexing with Tensor.
      // The list can be a list literal or list variable.
      // Advanced indexing using list:
      //   @torch.jit.script
      //   def f(x):
      //     return x[[0, 1, 5]]  # or
      //     return x[[0, 1], [0, 1]]  # or
      //     return x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]  # or
      //     ls = [0, 1]
      //     return x[ls]
      // Statements above are equivalent to advanced indexing using Tensor:
      //   @torch.jit.script
      //   def f(x):
      //     return x[torch.tensor([0, 1, 5])]  # or
      //     return x[torch.tensor([0, 1]), torch.tensor([0, 1])]  # or
      //     return x[torch.tensor([[0, 1], [0, 1]]),
      //              torch.tensor([[0, 1], [0, 1]])]  # or
      //     ls = [0, 1]
      //     return x[torch.tensor(ls)]
      if (index->type()->kind() == c10::TypeKind::ListType) {
        // Always create index tensor as LongTensor.
        // This is to match Pytorch eager frontend behavior which accepts
        // indexing with float list.
        index = graph->insert(
            aten::tensor, {index}, {NamedValue("dtype", c10::kLong)});
      }

      exprs[expr_idx] = index;
      if (index->type()->isSubtypeOf(*NoneType::get())) {
        // None adds a dimension (unsqueeze), so walking forward the next
        // subscript moves one dimension over; in reverse it stays put.
        if (is_reverse) {
          return dim;
        } else {
          return dim + 1;
        }
      } else if (index->type() == IntType::get()) {
        // Integer indexing removes a dimension (select).
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim;
        }
      } else if (index->type()->isSubtypeOf(*OptionalType::ofTensor())) {
        if (is_reverse) {
          throw(
              ErrorReport(loc)
              << "Ellipses followed by tensor indexing is currently not supported");
        } else {
          return dim + 1;
        }
      } else {
        throw(
            ErrorReport(loc)
            << "Unsupported operation: indexing tensor with unsupported index type '"
            << index->type()->repr_str()
            << "'. Only ints, slices, lists and tensors are supported");
      }
    };

    // Forward pass: handle subscripts up to the first ellipsis ('...').
    size_t idx = 0;
    int64_t dim = 0;
    for (; idx < subscript_exprs.size(); idx++) {
      auto subscript_expr = subscript_exprs[idx];
      if (subscript_expr.kind() == TK_DOTS) {
        break;
      }
      dim = handle_indexing(subscript_expr, idx, dim, /*is_reverse=*/false);
    }
    // Reverse pass: subscripts after the ellipsis index from the right,
    // starting at dimension -1.
    int64_t rdim = -1;
    for (size_t rev_idx = subscript_exprs.size() - 1; rev_idx > idx;
         rev_idx--) {
      auto subscript_expr = subscript_exprs[rev_idx];
      if (subscript_expr.kind() == TK_DOTS) {
        throw(
            ErrorReport(loc)
            << "An index can only have a single ellipsis ('...')");
      }
      rdim =
          handle_indexing(subscript_expr, rev_idx, rdim, /*is_reverse=*/true);
    }
    // Apply all slice/None/int subscripts at their computed dimensions;
    // tensor subscripts are only collected here, for a later aten::index.
    for (const auto i : c10::irange(exprs.size())) {
      if (!exprs[i].has_value()) {
        if (subscript_exprs[i].kind() == TK_SLICE_EXPR) {
          sliceable = emitSlice(
              loc,
              sliceable,
              insert_value_for_dim(dims[i]),
              SliceExpr(subscript_exprs[i]));
          continue;
        }

        if (subscript_exprs[i].kind() == TK_DOTS) {
          continue;
        }

        auto subscript_sv = emitSugaredExpr(subscript_exprs[i], 1);
        if (const auto slice_value =
                dynamic_cast<SliceValue*>(subscript_sv.get())) {
          sliceable = emitSliceOp(
              loc,
              sliceable,
              insert_value_for_dim(dims[i]),
              slice_value->start(),
              slice_value->stop(),
              slice_value->step());
        }

        continue;
      }
      auto expr = exprs[i].value();
      if (expr->type()->isSubtypeOf(*NoneType::get())) {
        sliceable =
            emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
      } else if (expr->type() == IntType::get()) {
        sliceable =
            emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
      } else if (expr->type()->isSubtypeOf(*OptionalType::ofTensor())) {
        // NOTE(review): dims[i] is assumed non-negative here — tensor
        // indexing after an ellipsis was rejected in handle_indexing.
        tensor_indices.resize(dims[i] + 1);
        tensor_indices[dims[i]] = expr;
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Trying to process index type that we don't support.");
      }
    }
    // at::index takes in a List[Optional[Tensor]] where some dims can be None.
    // create None node with optional tensor output type and pass to at::index.
    for (auto& index : tensor_indices) {
      if (index == nullptr) {
        index = graph->insertNode(graph->createNone())->output();
      }
    }
    return std::make_pair(sliceable, tensor_indices);
  }
5123
5124 // Desugars multidim slicing into slice/select/index/unsqueeze calls.
5125 //
5126 // XXX: Errors in user code are not elegantly reported.
5127 // Let's say someone were to do the following:
5128 // @torch.jit.script
5129 // def fn(x):
5130 // return x[0, 1]
5131 // fn(torch.randn(5))
5132 // Because we desugar this into two aten::select ops, the error message
5133 // complains about aten::select failing rather than there "not being
5134 // enough dimensions to index".
5135 //
5136 // The strategy is to slice and select the tensor for int and slices first
5137 // in one pass and then apply at::index on the result of the
5138 // slicing/selecting. Call the tensor after we've applied slice / select the
5139 // `sliced`. tensor_indices should have the same size as sliced.dim():
5140 // - tensor_indices[i] = NULL if we should not index `sliced` at dim i
5141 // - tensor_indices[i] = t if we should index `sliced` at dim i with tensor t.
emitMultidimSlicingtorch::jit::to_ir5142 Value* emitMultidimSlicing(
5143 const SourceRange& loc,
5144 Value* sliceable,
5145 const List<Expr>& subscript_exprs) {
5146 if (!sliceable->type()->isSubtypeOf(*TensorType::get())) {
5147 throw(
5148 ErrorReport(loc)
5149 << "Unsupported operation: attempted to use multidimensional "
5150 << "indexing on a non-tensor type");
5151 }
5152
5153 std::vector<Value*> tensor_indices;
5154 std::tie(sliceable, tensor_indices) =
5155 emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);
5156
5157 if (tensor_indices.empty()) {
5158 // XXX: Might need to at::alias this when we support mutability
5159 return sliceable;
5160 }
5161
5162 return emitIndex(loc, sliceable, tensor_indices);
5163 }
5164
5165 // Desugars slice syntactic sugar tensor[begin:end] -> tensor.slice(begin,
5166 // end).
emitBasicSlicetorch::jit::to_ir5167 Value* emitBasicSlice(
5168 const SourceRange& loc,
5169 Value* sliceable,
5170 const List<Expr>& subscript_exprs) {
5171 AT_ASSERT(subscript_exprs.size() == 1);
5172 AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
5173 auto slice_exp = SliceExpr(subscript_exprs[0]);
5174 Value* maybe_dim = nullptr;
5175 if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
5176 // If the sliceable object is a tensor, specify a default dimension
5177 maybe_dim = graph->insertConstant(0, loc);
5178 }
5179 return emitSlice(loc, sliceable, maybe_dim, slice_exp);
5180 }
5181
getAdjTupleIndextorch::jit::to_ir5182 int64_t getAdjTupleIndex(
5183 const SourceRange& loc,
5184 const TupleTypePtr& tuple_type,
5185 int64_t input_index,
5186 bool allow_out_of_bounds) {
5187 // set index to be positive to simplify logic in runtime
5188 int64_t adj_index = input_index;
5189 int64_t tuple_len = static_cast<int64_t>(tuple_type->elements().size());
5190 if (input_index < 0) {
5191 adj_index = tuple_len + input_index;
5192 }
5193 if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
5194 throw(
5195 ErrorReport(loc) << "Tuple index out of range. Tuple is length "
5196 << tuple_len << " and index is " << input_index);
5197 }
5198 return adj_index;
5199 }
5200
5201 // When a list is marked const in a module, it gets converted to a tuple.
5202 // The result is indexing into a Tuple which contains only one type
5203 // is quite common. since indexing will likely be done in a for loop,
5204 // we do not want to invoke the overhead of converting the tuple to a list
5205 // each iter.
emitTupleIndextorch::jit::to_ir5206 Value* emitTupleIndex(
5207 const SourceRange& loc,
5208 Value* tuple_val,
5209 Value* idx_val) {
5210 auto tuple_typ = tuple_val->type()->cast<TupleType>();
5211 auto elems = tuple_typ->elements();
5212 TypePtr output_type;
5213 if (idx_val->type() != IntType::get()) {
5214 throw(ErrorReport(loc) << "tuple index must be an integer");
5215 }
5216 auto idx = toIValue(idx_val);
5217 if (!idx) {
5218 if (elems.empty() ||
5219 !convertibleToList(tuple_typ, ListType::create(elems[0]))) {
5220 throw(
5221 ErrorReport(loc)
5222 << "Cannot index into a " << tuple_typ->repr_str()
5223 << " with a non-integer literal because we cannot resolve the output type");
5224 }
5225 output_type = elems[0];
5226 } else {
5227 auto adj_index = getAdjTupleIndex(
5228 loc, tuple_typ, idx->toInt(), /*allow_out_of_bounds*/ false);
5229 output_type = elems[adj_index];
5230 }
5231 return graph
5232 ->insertNode(graph->createTupleIndex(tuple_val, idx_val, output_type))
5233 ->output();
5234 }
5235
getSliceIndtorch::jit::to_ir5236 int64_t getSliceInd(Value* idx_val, const SourceRange& loc) {
5237 auto ivalue = toIValue(idx_val);
5238 if (ivalue && ivalue->isInt()) {
5239 return ivalue->to<int64_t>();
5240 } else {
5241 throw(
5242 ErrorReport(loc) << "tuple slice indices must be integer constants");
5243 }
5244 }
5245
emitTupleSlicetorch::jit::to_ir5246 Value* emitTupleSlice(
5247 const SourceRange& loc,
5248 const NamedValue& tuple_val,
5249 const std::vector<std::optional<NamedValue>>& tuple_args) {
5250 auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
5251 auto tuple_len = tuple_type->elements().size();
5252 auto beg_val = tuple_args[0];
5253 auto end_val = tuple_args[1];
5254 auto step = tuple_args[2];
5255
5256 int64_t step_size = 1;
5257 if (step) {
5258 auto val = toIValue(step->value(*graph));
5259 TORCH_CHECK(val->isInt(), "Step size should always be an integer");
5260 step_size = val->to<int64_t>();
5261 }
5262
5263 int64_t beg = std::numeric_limits<int64_t>::max();
5264 if (beg_val) {
5265 beg = getAdjTupleIndex(
5266 loc, tuple_type, getSliceInd(beg_val->value(*graph), loc), true);
5267 }
5268
5269 int64_t end = std::numeric_limits<int64_t>::max();
5270 if (end_val) {
5271 end = getAdjTupleIndex(
5272 loc, tuple_type, getSliceInd(end_val->value(*graph), loc), true);
5273 }
5274
5275 int64_t num_values = slice_indices_adjust(
5276 static_cast<int64_t>(tuple_len), &beg, &end, step_size);
5277
5278 return graph
5279 ->insertNode(graph->createTupleSlice(
5280 tuple_val.value(*graph), beg, step_size, num_values))
5281 ->output();
5282 }
5283
  // Emits IR for a subscript expression `value[subscript]`. Dispatch order:
  //   - more than one subscript expression  -> multidimensional slicing
  //   - a slice expression (a:b:c)          -> module/tuple slicing or a
  //                                            basic slice
  //   - a single index expression           -> tuple index, tensor indexing,
  //                                            or the sugared value's getitem
  // `type_hint` is forwarded to getitem for values that need it to resolve
  // their output type.
  std::shared_ptr<SugaredValue> emitSubscript(
      const Subscript& subscript,
      TypePtr type_hint = nullptr) {
    const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1);
    const List<Expr>& subscript_exprs = subscript.subscript_exprs();
    const SourceRange& range = subscript.range();
    const SourceRange& val_range = subscript.value().range();
    if (subscript_exprs.size() != 1) {
      // e.g. `x[i, j]` -- always treated as multidimensional slicing.
      return std::make_shared<SimpleValue>(emitMultidimSlicing(
          range, sv->asValue(val_range, method), subscript_exprs));
    }
    if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
      // TODO @wconstab refactor using Symbol instead of string compare
      if (sv->kind() == "module") {
        // Slicing isn't currently implemented for Sequential/ModuleList,
        // but is implemented for Tuples, so a quick workaround is to
        // convert to a tuple of Modules for slicing support.
        auto s_tuple_val =
            sv->asTupleValue(val_range, method)->asValue(val_range, method);
        const SliceExpr& slice = SliceExpr(subscript_exprs[0]);
        // Build the three optional slice components (begin, end, step) in
        // the positional order emitTupleSlice expects; absent components
        // are represented as nullopt.
        std::vector<std::optional<NamedValue>> tuple_args;
        tuple_args.reserve(3);
        if (slice.start().present()) {
          auto begin = NamedValue(
              val_range, "begin", emitExpr(Expr(slice.start().get())));
          tuple_args.emplace_back(begin);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }

        if (slice.end().present()) {
          auto end =
              NamedValue(val_range, "end", emitExpr(Expr(slice.end().get())));
          tuple_args.emplace_back(end);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }

        if (slice.step().present()) {
          auto step =
              NamedValue(val_range, "step", emitExpr(Expr(slice.step().get())));
          tuple_args.emplace_back(step);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }
        auto tupleSliceValue =
            emitTupleSlice(val_range, s_tuple_val, tuple_args);
        return std::make_shared<SimpleValue>(tupleSliceValue);
      } else {
        return std::make_shared<SimpleValue>(emitBasicSlice(
            range, sv->asValue(val_range, method), subscript_exprs));
      }
    } else {
      AT_ASSERT(subscript_exprs.size() == 1);
      Value* sliceable = sv->asValue(val_range, method);

      // In case of subscript expression being a Python Slice object.
      auto subscript_sv = emitSugaredExpr(subscript_exprs[0], 1);
      if (const auto slice_value =
              dynamic_cast<SliceValue*>(subscript_sv.get())) {
        Value* dim = nullptr;
        // aten::slice.tensor needs an additional `dim` input.
        if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
          dim = method.graph()->insertConstant(0, val_range);
        }

        Value* sliced = emitSliceOp(
            val_range,
            sliceable,
            dim,
            slice_value->start(),
            slice_value->stop(),
            slice_value->step());
        return std::make_shared<SimpleValue>(sliced);
      }

      // subscript is not a slice object, then it must be convertible to
      // a normal value.
      // Desugars gather syntactic sugar foo[i]
      Value* idx = subscript_sv->asValue(val_range, method);
      if (sliceable->type()->cast<TupleType>()) {
        return std::make_shared<SimpleValue>(
            emitTupleIndex(range, sv->asValue(val_range, method), idx));
      } else if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
        return std::make_shared<SimpleValue>(
            emitMultidimSlicing(range, sliceable, subscript_exprs));
      } else {
        // Fall back to the sugared value's own getitem (e.g. lists, dicts,
        // user classes); type_hint helps it resolve the output type.
        return sv->getitem(range, method, idx, std::move(type_hint));
      }
    }
  }
5375 };
5376
5377 struct FunctionResolver : public Resolver {
FunctionResolvertorch::jit::FunctionResolver5378 explicit FunctionResolver(
5379 Resolver* otherResolver,
5380 const std::unordered_map<std::string, Function*>& functionTable)
5381 : otherResolver_(otherResolver), functionTable_(functionTable) {}
5382
resolveValuetorch::jit::FunctionResolver5383 std::shared_ptr<SugaredValue> resolveValue(
5384 const std::string& name,
5385 GraphFunction& m,
5386 const SourceRange& loc) override {
5387 auto it = functionTable_.find(name);
5388 if (it != functionTable_.end()) {
5389 return std::make_shared<FunctionValue>(it->second);
5390 }
5391 return otherResolver_->resolveValue(name, m, loc);
5392 }
5393
resolveTypetorch::jit::FunctionResolver5394 TypePtr resolveType(const std::string& name, const SourceRange& loc)
5395 override {
5396 return otherResolver_->resolveType(name, loc);
5397 }
5398
5399 private:
5400 Resolver* otherResolver_;
5401 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
5402 const std::unordered_map<std::string, Function*>& functionTable_;
5403 };
5404
CompilationUnit::CompilationUnit(const std::string& source)
    : CompilationUnit() {
  // Delegates to the default constructor, then compiles `source` with the
  // native resolver; no `self`, so functions land in the global namespace.
  define(std::nullopt, source, nativeResolver(), nullptr);
}
5410
// This pair represents a pair of functions (getter and setter) obtained from
// compiling a Property.
struct CompilationUnit::PropertyPair
    : public std::pair<std::unique_ptr<Function>, std::unique_ptr<Function>> {
  // `getter` must be non-null; `setter` may be null (read-only property).
  PropertyPair(
      std::unique_ptr<Function> getter,
      std::unique_ptr<Function> setter) {
    TORCH_INTERNAL_ASSERT(getter, "Property pair must have defined getter")
    this->first = std::move(getter);
    this->second = std::move(setter);
  }

  // Accessors return mutable references so callers can move the functions
  // out of the pair.
  std::unique_ptr<Function>& getGetter() {
    return this->first;
  }

  std::unique_ptr<Function>& getSetter() {
    return this->second;
  }
};
5431
define_property(const std::optional<c10::QualifiedName> & prefix,const Property & prop,const ResolverPtr & resolver,const Self * self,const std::unordered_map<std::string,Function * > & function_table,bool shouldMangle) const5432 CompilationUnit::PropertyPair CompilationUnit::define_property(
5433 const std::optional<c10::QualifiedName>& prefix,
5434 const Property& prop,
5435 const ResolverPtr& resolver,
5436 const Self* self,
5437 const std::unordered_map<std::string, Function*>& function_table,
5438 bool shouldMangle) const {
5439 // self must be defined because properties are features of classes and
5440 // modules.
5441 TORCH_INTERNAL_ASSERT(self);
5442
5443 // Compile the getter function.
5444 std::unique_ptr<Function> getter_fn = define(
5445 prefix, prop.getter(), resolver, self, function_table, shouldMangle);
5446
5447 // Compile the setter function if it exists.
5448 std::unique_ptr<Function> setter_fn = nullptr;
5449 if (prop.setter().present()) {
5450 setter_fn = define(
5451 prefix,
5452 prop.setter().get(),
5453 resolver,
5454 self,
5455 function_table,
5456 shouldMangle);
5457 }
5458
5459 // Add the property to the class type definition.
5460 self->getClassType()->addProperty(
5461 prop.name().name(), getter_fn.get(), setter_fn.get());
5462
5463 return PropertyPair(std::move(getter_fn), std::move(setter_fn));
5464 }
5465
// Compiles a single Def into a GraphFunction. IR emission is deferred: the
// returned function stores a `creator` callback that runs to_ir when the
// function body is first needed (ensure_defined).
std::unique_ptr<Function> CompilationUnit::define(
    const std::optional<QualifiedName>& prefix,
    const Def& def,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle,
    CompilationUnit::FunctionType type,
    std::optional<size_t> operator_set_version) const {
  TORCH_INTERNAL_ASSERT(resolver);
  auto _resolver = resolver;
  if (!self) {
    // if self is defined, then these are methods and do not go into the
    // global namespace otherwise, they get defined together so we add them to
    // the function table so the methods can see each other
    _resolver =
        std::make_shared<FunctionResolver>(resolver.get(), function_table);
  }
  // Captures def/_resolver/self by value so deferred compilation can run
  // after this call returns.
  auto creator = [def, _resolver, self](GraphFunction& method) {
    // Store the function name so that it can be referenced if there is an error
    // while compiling this function
    std::string call_name = method.qualname().name();
    if (self) {
      auto atoms = method.qualname().atoms();
      // There should be at least a ClassName.method_name
      TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
      call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1);
    }
    ErrorReport::CallStack call(call_name, def.range());
    to_ir(def, _resolver, self, method);
  };
  auto name = prefix ? QualifiedName(*prefix, def.name().name())
                     : QualifiedName(def.name().name());
  if (shouldMangle) {
    // If `shouldMangle` is set, we should generate a unique name for this
    // function if there is already an existing one.
    if (find_function(name)) {
      name = mangle(name);
    }
  }

  auto graph = std::make_shared<Graph>();
  // Record the operator set version on the graph (used for versioned
  // operator handling).
  graph->set_op_version(operator_set_version);

  auto fn = std::make_unique<GraphFunction>(std::move(name), graph, creator);
  if (self) {
    // Register this as a method on `self`'s type
    if (type == CompilationUnit::FunctionType::Hook) {
      self->getClassType()->addForwardHook(fn.get());
    } else if (type == CompilationUnit::FunctionType::PreHook) {
      self->getClassType()->addForwardPreHook(fn.get());
    } else {
      self->getClassType()->addMethod(fn.get());
    }
  }
  return fn;
}
5523
define(const std::optional<c10::QualifiedName> & prefix,const std::vector<Property> & properties,const std::vector<ResolverPtr> & propResolvers,const std::vector<Def> & definitions,const std::vector<ResolverPtr> & defResolvers,const Self * self,bool shouldMangle,std::optional<size_t> operator_set_version)5524 std::vector<Function*> CompilationUnit::define(
5525 const std::optional<c10::QualifiedName>& prefix,
5526 const std::vector<Property>& properties,
5527 const std::vector<ResolverPtr>& propResolvers,
5528 const std::vector<Def>& definitions,
5529 const std::vector<ResolverPtr>& defResolvers,
5530 const Self* self,
5531 bool shouldMangle,
5532 std::optional<size_t> operator_set_version) {
5533 TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size());
5534 TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size());
5535 std::vector<Function*> functions;
5536 std::unordered_map<std::string, Function*> function_table;
5537
5538 // Records fn in function_table, functions and with register_function.
5539 // This is done several times below, so this lambda helps avoid repeating
5540 // code.
5541 auto record_function = [&](std::unique_ptr<Function> fn) {
5542 function_table[fn->name()] = fn.get();
5543 functions.emplace_back(fn.get());
5544 this->register_function(std::move(fn));
5545 };
5546
5547 for (const auto i : c10::irange(properties.size())) {
5548 PropertyPair property_fns = define_property(
5549 prefix,
5550 properties[i],
5551 propResolvers[i],
5552 self,
5553 function_table,
5554 shouldMangle);
5555
5556 auto& getter_fn = property_fns.getGetter();
5557 auto& setter_fn = property_fns.getSetter();
5558
5559 record_function(std::move(getter_fn));
5560
5561 if (setter_fn) {
5562 record_function(std::move(setter_fn));
5563 }
5564 }
5565
5566 for (const auto i : c10::irange(definitions.size())) {
5567 auto fn = define(
5568 prefix,
5569 definitions[i],
5570 defResolvers[i],
5571 self,
5572 function_table,
5573 shouldMangle,
5574 CompilationUnit::FunctionType::Method,
5575 operator_set_version);
5576
5577 record_function(std::move(fn));
5578 }
5579
5580 // We need to compile `__init__` first, since it can determine what attributes
5581 // are available to other methods. So reorder the definitions accordingly.
5582 for (auto& kv : function_table) {
5583 if (kv.first == "__init__") {
5584 kv.second->ensure_defined();
5585 }
5586 }
5587
5588 for (Function* function : functions) {
5589 function->ensure_defined();
5590 }
5591
5592 return functions;
5593 }
5594
define_hooks(const std::optional<c10::QualifiedName> & prefix,const std::vector<Def> & hookDefs,const std::vector<ResolverPtr> & hookResolvers,const std::vector<Def> & preHookDefs,const std::vector<ResolverPtr> & preHookResolvers,const Self * self,bool shouldMangle)5595 void CompilationUnit::define_hooks(
5596 const std::optional<c10::QualifiedName>& prefix,
5597 const std::vector<Def>& hookDefs,
5598 const std::vector<ResolverPtr>& hookResolvers,
5599 const std::vector<Def>& preHookDefs,
5600 const std::vector<ResolverPtr>& preHookResolvers,
5601 const Self* self,
5602 bool shouldMangle) {
5603 TORCH_INTERNAL_ASSERT(hookDefs.size() == hookResolvers.size());
5604 TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookResolvers.size());
5605 std::vector<Function*> functions;
5606 std::unordered_map<std::string, Function*> function_table;
5607
5608 // check hook for name collisions and redefinition
5609 auto check_collisions = [&](const Def& hook) -> Function* {
5610 auto name = prefix ? QualifiedName(*prefix, hook.name().name()).name()
5611 : QualifiedName(hook.name().name()).name();
5612 // check if hook is already defined for this module
5613 auto found_hook = function_table.find(name);
5614 auto existing_hook =
5615 found_hook != function_table.end() ? found_hook->second : nullptr;
5616 // check if hook name is already defined on module as method
5617 if (existing_hook == nullptr) {
5618 TORCH_CHECK(
5619 self->getClassType()->findMethod(name) == nullptr &&
5620 self->getClassType()->findHook(name) == nullptr,
5621 "Can't define hook: ",
5622 name,
5623 " on class: ",
5624 self->getClassType()->repr_str(),
5625 " because a method or hook with that name already exists.");
5626 }
5627 return existing_hook;
5628 };
5629
5630 // build_schema for checking
5631 auto build_schema = [&](const Def& hook_def,
5632 const ResolverPtr& hook_res) -> FunctionSchema {
5633 ScriptTypeParser typeParser(hook_res);
5634 FunctionSchema schema =
5635 typeParser.parseSchemaFromDef(hook_def, true /* skip_self*/);
5636 // need to add self as the first because we skipped it
5637 std::vector<Argument> arguments;
5638 arguments.emplace_back(
5639 hook_def.decl().params()[0].ident().name(), self->getClassType());
5640 arguments.insert(
5641 arguments.end(), schema.arguments().begin(), schema.arguments().end());
5642 return schema.cloneWithArguments(arguments);
5643 };
5644
5645 // define hooks
5646 for (const auto i : c10::irange(hookDefs.size())) {
5647 // check to see if already defined this hook
5648 auto existing_fn = check_collisions(hookDefs[i]);
5649 if (existing_fn != nullptr) {
5650 // add it to class type again so it's called
5651 self->getClassType()->addForwardHook(existing_fn);
5652 continue;
5653 }
5654 // define hook
5655 auto fn = define(
5656 prefix,
5657 hookDefs[i],
5658 hookResolvers[i],
5659 self,
5660 function_table,
5661 shouldMangle,
5662 CompilationUnit::FunctionType::Hook);
5663
5664 function_table[fn->name()] = fn.get();
5665 functions.emplace_back(fn.get());
5666 this->register_function(std::move(fn));
5667 self->getClassType()->checkForwardHookSchema(
5668 i, build_schema(hookDefs[i], hookResolvers[i]));
5669 functions.back()->ensure_defined();
5670 }
5671
5672 // define pre_hooks
5673 for (const auto i : c10::irange(preHookDefs.size())) {
5674 // check to see if already defined this hook
5675 auto existing_fn = check_collisions(preHookDefs[i]);
5676 if (existing_fn != nullptr) {
5677 // add it to class type again so it's called
5678 self->getClassType()->addForwardPreHook(existing_fn);
5679 continue;
5680 }
5681 // define pre_hook
5682 auto fn = define(
5683 prefix,
5684 preHookDefs[i],
5685 preHookResolvers[i],
5686 self,
5687 function_table,
5688 shouldMangle,
5689 CompilationUnit::FunctionType::PreHook);
5690
5691 function_table[fn->name()] = fn.get();
5692 functions.emplace_back(fn.get());
5693 this->register_function(std::move(fn));
5694 self->getClassType()->checkForwardPreHookSchema(
5695 i, build_schema(preHookDefs[i], preHookResolvers[i]));
5696 functions.back()->ensure_defined();
5697 }
5698 }
5699
define(const std::optional<QualifiedName> & prefix,const std::string & source,const ResolverPtr & resolver,const Self * self)5700 std::vector<Function*> CompilationUnit::define(
5701 const std::optional<QualifiedName>& prefix,
5702 const std::string& source,
5703 const ResolverPtr& resolver,
5704 const Self* self) {
5705 Parser p(std::make_shared<Source>(source, "<string>", 1));
5706 std::vector<Def> definitions;
5707 std::vector<ResolverPtr> resolvers;
5708 while (p.lexer().cur().kind != TK_EOF) {
5709 auto def = Def(p.parseFunction(/*is_method=*/bool(self)));
5710 definitions.push_back(def);
5711 resolvers.push_back(resolver);
5712 }
5713 return define(
5714 prefix,
5715 /*properties=*/{},
5716 /*propResolvers=*/{},
5717 definitions,
5718 resolvers,
5719 self);
5720 }
5721
eraseListLiterals(std::shared_ptr<Graph> & graph)5722 static void eraseListLiterals(std::shared_ptr<Graph>& graph) {
5723 DepthFirstGraphNodeIterator it(graph);
5724
5725 for (auto next_node = it.next(); next_node != nullptr;) {
5726 Node* node = next_node;
5727 next_node = it.next();
5728
5729 if (node->kind() == prim::EmptyListLiteral) {
5730 if (node->hasUses()) {
5731 TORCH_INTERNAL_ASSERT(
5732 node->output()->type()->isSubtypeOf(ListType::ofTensors()));
5733
5734 auto li = graph->createList(TensorType::get(), {});
5735 li->insertBefore(node);
5736 node->replaceAllUsesWith(li);
5737 }
5738 node->destroy();
5739 }
5740 }
5741 }
5742
// Runs the post-compilation cleanup pipeline on a freshly emitted graph.
// NOTE: pass ordering matters -- see the per-pass comments below.
void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
  // Closure handling comes first so later passes see a closure-free graph.
  liftClosures(to_clean);
  inlineForkedClosures(to_clean);

  if (getInlineEverythingMode()) {
    Inline(*to_clean);
  }

  // these exist temporarily in initial compilation
  eraseListLiterals(to_clean);

  // remove any uses of tuples that we inserted that are not needed
  LowerSimpleTuples(to_clean);

  // full constant propagation runs ops with mutable inputs if it can
  // prove that the inputs are not mutated anywhere in the graph.
  // if a mutating node is removed in the graph (e.g. constant prop inlined a
  // a constant if) then the next time constant prop is run it might be able
  // to run nodes it was not able to previously, and the graph may change
  // (jitter) So we run only constant prop w immutable types here bc
  // successive runs of immutable constant prop does not change the graph
  ConstantPropagationImmutableTypes(to_clean);

  // Constant Pooling pass must be after ConstantPropagation, which can create
  // new constants that needs to be pooled.
  ConstantPooling(to_clean);

  // For jitter
  CanonicalizeOutputs(to_clean);

  // Annotate aten::warns so that each has its unique ID. This enables us to
  // mimic Python behavior of only emitting each warning only once.
  AnnotateWarns(to_clean);
}
5777
// we consider _N where N is a number, to be a non-meaningful name
// and do not record it as a unique name. This allows python printing to
// be able to export and import more consistently named graphs
//
// Returns true iff `name` should be preserved when printing a graph.
// Empty names, compiler temporaries ("$..."), and autogenerated names of
// the form "_<digits>" (including a bare "_") are not meaningful.
bool meaningfulName(const std::string& name) {
  if (name.empty()) {
    return false;
  }
  if (name[0] == '$') {
    return false;
  }
  if (name[0] != '_') {
    return true;
  }
  for (size_t i = 1; i < name.size(); ++i) {
    // Fix: isdigit has undefined behavior for negative char values; cast
    // through unsigned char before the check.
    if (!isdigit(static_cast<unsigned char>(name[i]))) {
      return true;
    }
  }
  // "_" followed only by digits (or nothing): autogenerated, not meaningful.
  return false;
}
5794
define_interface(const c10::QualifiedName & qualifiedName,const ClassDef & classDef,ResolverPtr rcb,bool is_module)5795 void CompilationUnit::define_interface(
5796 const c10::QualifiedName& qualifiedName,
5797 const ClassDef& classDef,
5798 ResolverPtr rcb,
5799 bool is_module) {
5800 ScriptTypeParser typeParser(std::move(rcb));
5801 InterfaceTypePtr iface =
5802 InterfaceType::create(c10::QualifiedName(qualifiedName), is_module);
5803 for (const Stmt& stmt : classDef.body()) {
5804 if (stmt.kind() != TK_DEF) {
5805 throw(
5806 ErrorReport(stmt)
5807 << "interface declarations can only contain method definitions");
5808 }
5809 auto method_def = Def(stmt);
5810 if (!method_def.decl().return_type().present()) {
5811 throw(
5812 ErrorReport(method_def)
5813 << "interface declarations must have a return type annotated.");
5814 }
5815 FunctionSchema schema =
5816 typeParser.parseSchemaFromDef(method_def, /* skip_self*/ true);
5817 // need to add self as the first because we skipped it
5818 std::vector<Argument> arguments;
5819 arguments.emplace_back(method_def.decl().params()[0].ident().name(), iface);
5820 arguments.insert(
5821 arguments.end(), schema.arguments().begin(), schema.arguments().end());
5822 iface->addMethod(schema.cloneWithArguments(std::move(arguments)));
5823 // we need to make sure everything but the last element is just string
5824 // literals (aka comments) unless there is "pass" in between
5825 auto stmts_size = method_def.statements().size();
5826 for (size_t i = 0; i < stmts_size - 1; i++) {
5827 auto cur_statement = method_def.statements()[i];
5828 if (cur_statement.kind() == TK_EXPR_STMT) {
5829 auto expr = ExprStmt(cur_statement).expr();
5830 if (expr.kind() != TK_STRINGLITERAL) {
5831 throw(
5832 ErrorReport(method_def.range())
5833 << "interfaces declarations should only contain a single 'pass' statement.");
5834 }
5835 }
5836 // if we see a "pass", we just stop there
5837 if (cur_statement.kind() == TK_PASS) {
5838 this->register_type(iface);
5839 return;
5840 }
5841 }
5842
5843 if (method_def.statements()[stmts_size - 1].kind() != TK_PASS) {
5844 throw(
5845 ErrorReport(method_def.range())
5846 << "interfaces declarations should contain 'pass' statement.");
5847 }
5848 }
5849 this->register_type(iface);
5850 }
5851
5852 } // namespace torch::jit
5853