#include <torch/csrc/jit/frontend/schema_matching.h>

#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/frontend/builtin_functions.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <optional>

namespace torch::jit {

static inline TypePtr unwrapOptional(TypePtr opt_type) {
  if (auto dyn = opt_type->castRaw<c10::DynamicType>()) {
    return unwrapOptional(dyn->fallback());
  }
  if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
    return unwrap_list_type->getElementType();
  }
  return opt_type;
}
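// For example, unwrapOptional(Optional[int[]]) yields int[]; a DynamicType is
// first replaced by its fallback before the Optional is unwrapped, and a
// non-optional type is returned unchanged.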

static inline bool isIntOrFloatUsedAsList(
    const Value* value,
    const Argument& arg) {
  // Look for int[N] or float[N]
  const auto& v_type = value->type();
  if (v_type != FloatType::get() && v_type != IntType::get())
    return false;
  auto arg_type = unwrapOptional(arg.type());
  auto list_type = arg_type->cast<ListType>();
  return list_type && list_type->getElementType() == v_type && arg.N();
}
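// For example, aten::conv2d declares `int[2] stride`; passing the single int
// 1 satisfies this check, and tryMatchArgument below expands it to [1, 1].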

/// Returns true if `type` is a subtype of `list_type_`, or if it is a Tuple
/// whose elements are all subtypes of the element type of `list_type_`.
bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
  auto list_type = list_type_->castRaw<ListType>();
  if (!list_type) {
    return false;
  }
  if (type->isSubtypeOf(*list_type_)) {
    return true;
  }
  if (auto tuple = type->castRaw<TupleType>()) {
    return std::all_of(
        tuple->elements().begin(),
        tuple->elements().end(),
        [&](const TypePtr& t) {
          // TODO: resolve VarType if necessary
          return t->isSubtypeOf(*list_type->getElementType());
        });
  }
  return false;
}
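// For example, a (int, int) tuple is convertible to int[], but a (int, str)
// tuple is not, since str is not a subtype of int.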

// Applies implicit conversions to `value` to try to turn it into type
// `concrete_type`. The conversion succeeded if the returned value satisfies
// `return_value->isSubtypeOf(concrete_type)`.
Value* tryConvertToType(
    const SourceRange& loc,
    Graph& graph,
    const TypePtr& concrete_type,
    Value* value,
    bool allow_conversions) {
  // treat conversion to Optional[T] as conversion to T
  if (OptionalTypePtr op = concrete_type->cast<OptionalType>()) {
    if (value->type()->kind() != OptionalType::Kind &&
        !value->type()->isSubtypeOf(*NoneType::get())) {
      return tryConvertToType(
          loc, graph, op->getElementType(), value, allow_conversions);
    }
  }

  // allow temporary, unannotated list literals `[]` to match arbitrary list
  // types
  if (value->node()->kind() == prim::EmptyListLiteral &&
      concrete_type->cast<ListType>()) {
    value = graph
                .insertNode(graph.createList(
                    concrete_type->cast<ListType>()->getElementType(), {}))
                ->output();
  }

  if (auto value_tuple = value->type()->cast<TupleType>()) {
    // Allow homogeneous tuples to be cast implicitly to lists of the
    // appropriate type
    if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
      auto unpacked = createTupleUnpack(value);
      auto elem_type =
          unwrapOptional(concrete_type)->expectRef<ListType>().getElementType();
      value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
    }

    // inductively apply implicit conversions to tuples
    if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
      if (!value_tuple->isSubtypeOf(*concrete_tuple) &&
          concrete_tuple->elements().size() == value_tuple->elements().size()) {
        auto unpacked = createTupleUnpack(value);
        std::vector<Value*> converted;
        for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
          converted.emplace_back(tryConvertToType(
              loc,
              graph,
              concrete_tuple->elements().at(i),
              unpacked.at(i),
              allow_conversions));
        }
        value = graph.insertNode(graph.createTuple(converted))->output();
      }
    }
  }

  // implicit conversions
  if (allow_conversions) {
    // Convert a tensor or number to a concrete int/float type
    bool value_isa_tensor = value->type()->isSubtypeOf(*TensorType::get());
    bool value_equals_number = *value->type() == *NumberType::get();
    bool concrete_float = *concrete_type == *FloatType::get();
    bool concrete_complex = *concrete_type == *ComplexType::get();
    bool concrete_int = *concrete_type == *IntType::get();
    bool concrete_number = *concrete_type == *NumberType::get();
    if (value_isa_tensor) {
      if (concrete_float) {
        value = graph.insert(aten::FloatImplicit, {value}, {}, loc);
      } else if (concrete_complex) {
        value = graph.insert(aten::ComplexImplicit, {value}, {}, loc);
      } else if (concrete_int) {
        value = graph.insert(aten::IntImplicit, {value}, {}, loc);
      } else if (concrete_number) {
        value = graph.insert(aten::ScalarImplicit, {value}, {}, loc);
      }
    } else if (value_equals_number) {
      if (concrete_float) {
        value = graph.insert(aten::Float, {value}, {}, loc);
      } else if (concrete_complex) {
        value = graph.insert(aten::Complex, {value}, {}, loc);
      } else if (concrete_int) {
        value = graph.insert(aten::Int, {value}, {}, loc);
      }
    } else if (*value->type() == *BoolType::get()) {
      if (concrete_float) {
        value = graph.insert(aten::Float, {value}, {}, loc);
      } else if (concrete_int || concrete_number) {
        value = graph.insert(aten::Int, {value}, {}, loc);
      }
    }

    // Convert strings to device
    if (value->type()->isSubtypeOf(*StringType::get()) &&
        concrete_type->isSubtypeOf(*DeviceObjType::get())) {
      return graph.insert(aten::device, {value}, {}, loc);
    }
  }

  return value;
}
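// For example, a Tensor passed where float is expected becomes an
// aten::FloatImplicit node, and the string "cuda:0" passed where Device is
// expected becomes an aten::device node.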

// Checks if `named_value` can be used as a value for `arg`. If `arg` is a
// VarType, it will be added to the type_env through `matchTypeVariables` as
// the corresponding actual type. If `allow_conversions` is true, implicit
// conversions to the `arg` type may be performed through `tryConvertToType`.
static Value* tryMatchArgument(
    const Argument& arg,
    Graph& graph,
    const SourceRange& loc,
    const NamedValue& named_value,
    std::ostream* failure_messages,
    const std::function<std::ostream&()>& err,
    bool allow_conversions,
    TypeEnv& type_env) {
  Value* value = named_value.value(graph);

  // Some functions that take lists of integers or floats for fixed size
  // arrays also allow a single int/float to be passed in their place. The
  // single int/float is then repeated to the length of the list.
  if (isIntOrFloatUsedAsList(value, arg)) {
    std::vector<Value*> repeated(*arg.N(), value);
    value =
        graph.insertNode(graph.createList(value->type(), repeated))->output();
  }

  // Resolve VarType variables
  const MatchTypeReturn matched =
      matchTypeVariables(arg.type(), value->type(), type_env);
  if (!matched.success()) {
    if (failure_messages) {
      err() << "Could not match type " << value->type()->repr_str() << " to "
            << arg.type()->repr_str() << " in argument '" << arg.name()
            << "': " << matched.reason() << ".\n";
    }
    return nullptr;
  }
  const auto concrete_type = tryEvalTypeVariables(arg.type(), type_env);
  if (!concrete_type) {
    if (failure_messages) {
      err() << "Type variables in type " << arg.type()->repr_str()
            << " could not be inferred from actual type "
            << value->type()->repr_str();
    }
    return nullptr;
  }

  // Check if the value can be matched to the arg through any implicit
  // conversions
  value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
  std::stringstream ss;
  if (!value->type()->isSubtypeOfExt(
          *concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
    if (failure_messages) {
      auto& ostream = err()
          << arg.formatTypeMismatchMsg(value->type()->repr_str());

      if (auto pt = value->type()->cast<TensorType>()) {
        if (pt->isInferredType()) {
          std::string inferred_type_hint;
          inferred_type_hint = c10::str(
              "Inferred the value for argument '",
              arg.name(),
              "' to be of type 'Tensor' ",
              "because it was not annotated with an explicit type.\n");
          ostream << inferred_type_hint;
        }
      }

      if (auto v = value->type()->cast<ListType>()) {
        if (v->getElementType()->isSubtypeOf(*TensorType::get())) {
          ostream << "Empty lists default to List[Tensor]. Add a variable "
                     "annotation to the assignment to create an empty list "
                     "of another type (torch.jit.annotate(List[T], []) where T "
                     "is the type of elements in the list for Python 2)\n";
        }
      }

      ostream << ss.str();
    }

    return nullptr;
  }
  return value;
}
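// For example, matching the int 3 against `int[2] kernel_size` inserts a
// prim::ListConstruct producing [3, 3] before the subtype check runs.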

std::optional<size_t> findInputWithName(
    const std::string& name,
    at::ArrayRef<NamedValue> kwargs,
    bool is_aten) {
  for (const auto i : c10::irange(kwargs.size())) {
    // TS doesn't understand that the self argument in function
    // schemas is renamed to input for the functional variant
    if (is_aten && name == "self" && kwargs[i].name() == "input") {
      return i;
    }
    if (kwargs[i].name() == name) {
      return i;
    }
  }
  return std::nullopt;
}
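// For example, torch.add(input=t, other=u) matches the aten schema
// add(Tensor self, Tensor other, ...), because the keyword `input` is
// accepted for the formal parameter named `self`.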

/// Creates a list with the provided values if each value's type can be matched
/// to an argument with type `elem_type`. If a type in `varargs` does not match
/// `elem_type`, nullptr is returned. This is used for creating lists from
/// varargs so that calls like torch.zeros(1, 2, 3) will be matched to
/// aten::zeros(int[]).
static Value* tryCreateList(
    const TypePtr& elem_type,
    Graph& graph,
    const SourceRange& loc,
    at::ArrayRef<NamedValue> varargs,
    std::ostream* failure_messages,
    const std::function<std::ostream&()>& err,
    bool convert_tensor_to_num,
    TypeEnv& type_env) {
  Argument elem_arg("<varargs>", elem_type);
  std::vector<Value*> list_elements;
  for (const auto& named_value : varargs) {
    // Try to convert named_value to elem_type
    Value* matched_value = tryMatchArgument(
        /*arg=*/elem_arg,
        graph,
        loc,
        named_value,
        failure_messages,
        err,
        /*allow_conversions=*/convert_tensor_to_num,
        type_env);
    if (!matched_value) {
      return nullptr;
    }
    list_elements.push_back(matched_value);
  }

  return graph.insertNode(graph.createList(elem_type, list_elements))->output();
}

// Check if it is possible to convert all the remaining non-kwarg arguments
// to a list. This allows zeros(IntArrayRef sizes) to work with zeros(1, 2) or
// zeros(1).
static bool varargsCanBeUsedAsList(
    const FunctionSchema& schema,
    size_t arg_index,
    const Argument& arg) {
  // The arg must be the last one in the arg list that is not a kwarg
  bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
      schema.arguments()[arg_index + 1].kwarg_only();

  auto arg_type = arg.type();
  if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
    arg_type = dyn->fallback();
  }

  // The formal must be a list
  bool argument_is_list = arg_type->kind() == TypeKind::ListType;

  // Matching varargs against a list of type variables is not yet implemented
  bool typevar_list = argument_is_list &&
      arg_type->castRaw<ListType>()->getElementType()->cast<VarType>();

  // It must not be a broadcasting list like int[3];
  // otherwise a single int is a valid input
  bool arg_is_broadcasting_list = bool(arg.N());

  return is_last_argument && argument_is_list && !arg_is_broadcasting_list &&
      !typevar_list;
}
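// For example, in aten::zeros(int[] size, ...), `size` is the last positional
// list argument, so zeros(1, 2, 3) can collect 1, 2, 3 into [1, 2, 3]. A
// broadcasting list such as int[3] is excluded because a lone int already
// matches it directly.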

bool isBlockListedSchema(const FunctionSchema& schema) {
  // Note (@zasdfgbnm):
  // This is a workaround for https://github.com/pytorch/pytorch/issues/47964.
  // Currently JIT does not distinguish ScalarType vs int, so there is really
  // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to
  // hardcode aten::view.dtype here to block this overload. This blocklist
  // should be removed when JIT fully supports ScalarType as its own type.
  if (schema.name() == "aten::view" && schema.overload_name() == "dtype") {
    return true;
  }
  // Note (@tugsbayasgalan)
  // TorchScript doesn't support kwargs, so this op collides with
  // aten.max.other since both of them have 2 Tensor inputs. Since we don't
  // expect users to use this op in TS, we just skip it.
  if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") {
    return true;
  }
  if (schema.name() == "aten::min" && schema.overload_name() == "unary_out") {
    return true;
  }
  return false;
}

static std::optional<MatchedSchema> tryMatchSchema(
    const FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    std::optional<NamedValue> self,
    std::ostream* failure_messages,
    bool allow_conversions) {
  if (isBlockListedSchema(schema)) {
    return std::nullopt;
  }

  auto err = [&]() -> std::ostream& {
    *failure_messages << "\n" << schema << ":\n";
    return *failure_messages;
  };

  // For VarTypes, maps VarType name to the actual type as it's used with
  // these args
  TypeEnv type_env;
  std::vector<Value*> positional_inputs;
  std::vector<bool> used_kwarg(kwargs.size(), false);

  auto schema_namespace = schema.operator_name().getNamespace();
  bool is_aten = false;
  if (schema_namespace.has_value()) {
    if (schema_namespace.value() == "aten") {
      is_aten = true;
    }
  }
  // if we finish the loop, will we have consumed all arguments?
  size_t used_args = 0;
  for (const auto schema_i : c10::irange(schema.arguments().size())) {
    const auto& arg = schema.arguments()[schema_i];
    std::optional<NamedValue> actual_named_value;
    if (arg.name() == "self" && self) {
      actual_named_value = self;
      self = std::nullopt;
    } else if (!arg.kwarg_only() && used_args < args.size()) {
      // Try to convert all the remaining non-kwarg arguments (used_args) to a
      // list. Allow zeros(IntArrayRef sizes) to work with zeros(1, 2) or
      // zeros(1)
      if (allow_conversions && varargsCanBeUsedAsList(schema, schema_i, arg)) {
        auto value = args[used_args].value(graph);
        const auto& actual_type = value->type();
        // The actual cannot already be a list
        if (actual_type->kind() != TypeKind::ListType &&
            !convertibleToList(actual_type, unwrapOptional(arg.type()))) {
          auto formal_type = unwrapOptional(arg.type())
                                 ->expectRef<ListType>()
                                 .getElementType();

          Value* list = tryCreateList(
              formal_type,
              graph,
              loc,
              at::ArrayRef<NamedValue>(args).slice(used_args),
              failure_messages,
              err,
              allow_conversions,
              type_env);
          if (!list) {
            return std::nullopt;
          }
          used_args = args.size();
          positional_inputs.push_back(list);
          continue;
        }
      }

      // Set actual_named_value to the argument and mark the arg position as
      // used
      actual_named_value = args[used_args];
      used_args++;
    } else if (
        auto kwarg_idx = findInputWithName(arg.name(), kwargs, is_aten)) {
      const NamedValue& nv = kwargs[*kwarg_idx];
      if (used_kwarg[*kwarg_idx]) {
        if (failure_messages) {
          err() << "Argument " << nv.name()
                << " specified twice in schema, submit a bug report!\n";
        }
        return std::nullopt;
      }
      used_kwarg[*kwarg_idx] = true;
      actual_named_value = nv;
    } else if (arg.default_value()) {
      // Argument has a default value and no value was provided, so use the
      // default
      actual_named_value = NamedValue(*arg.default_value());
    } else {
      if (failure_messages) {
        err() << "Argument " << schema.arguments()[schema_i].name()
              << " not provided.\n";
      }
      return std::nullopt;
    }

    // Make sure the actual_named_value found matches the type of arg
    Value* positional = tryMatchArgument(
        arg,
        graph,
        loc,
        *actual_named_value,
        failure_messages,
        err,
        allow_conversions,
        type_env);
    if (!positional) {
      return std::nullopt;
    }
    positional_inputs.push_back(positional);
  }
  // check for unused self argument
  if (self != std::nullopt) {
    if (failure_messages) {
      err() << "Provided self argument not used in schema.\n";
    }
    return std::nullopt;
  }

  if (schema.is_vararg()) {
    for (; used_args < args.size(); ++used_args) {
      positional_inputs.push_back(args[used_args].value(graph));
    }
  }

  // check for unused positional arguments
  if (used_args < args.size()) {
    if (failure_messages) {
      err() << "Expected at most " << used_args << " arguments "
            << "but found " << args.size() << " positional arguments.\n";
    }
    return std::nullopt;
  }
  // check for unused kwargs
  for (const auto i : c10::irange(kwargs.size())) {
    const auto& nv = kwargs[i];
    if (!used_kwarg[i]) {
      if (failure_messages) {
        if (!schema.argumentIndexWithName(nv.name())) {
          err() << "Keyword argument " << nv.name() << " unknown.\n";
        } else {
          err() << "Keyword argument " << nv.name() << " specified twice.\n";
        }
      }
      return std::nullopt;
    }
  }

  const auto& returns = schema.returns();
  auto return_types = fmap(returns, [&](const Argument& r) {
    TypePtr result = tryEvalTypeVariables(r.type(), type_env);
    TORCH_INTERNAL_ASSERT(
        result, r.type()->repr_str(), " has unbound type variables.");
    return result;
  });
  // Codegen does not support returning namedtuples with undefined field
  // names, so either all returns have field names or none do.
  bool return_has_field_names =
      std::all_of(returns.begin(), returns.end(), [&](const Argument& r) {
        return !r.name().empty();
      });
  c10::OptNameList return_field_names = std::nullopt;
  if (return_has_field_names) {
    return_field_names =
        fmap(returns, [&](const Argument& r) { return r.name(); });
  }

  // construct the full name of the schema for easier lookup
  auto schema_name = getFullSchemaName(schema);

  return MatchedSchema{
      std::move(positional_inputs),
      std::move(return_types),
      std::move(return_field_names),
      schema_name};
}

MatchedSchema matchSchema(
    const ::c10::FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const std::optional<NamedValue>& self) {
  std::stringstream failure_messages;
  if (auto result = tryMatchSchema(
          schema,
          loc,
          graph,
          args,
          kwargs,
          self,
          &failure_messages,
          /*allow_conversions=*/true)) {
    return *result;
  }
  throw(ErrorReport(loc) << failure_messages.str());
}
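// A minimal usage sketch (the graph `g` and values `lhs`/`rhs` are
// hypothetical):
//
//   auto schema = parseSchema(
//       "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
//   MatchedSchema m = matchSchema(
//       schema, loc, g, {NamedValue(lhs), NamedValue(rhs)},
//       /*kwargs=*/{}, /*self=*/std::nullopt);
//   // m.inputs holds lhs, rhs, and an inserted constant 1 for the defaulted
//   // alpha; m.return_types holds a single Tensor type.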

static std::string prefixLine(
    const std::string& str,
    const std::string& prefix) {
  std::stringstream ss;
  bool was_newline = true;
  for (auto c : str) {
    if (was_newline)
      ss << prefix;
    ss.put(c);
    was_newline = c == '\n';
  }
  return ss.str();
}
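// For example, prefixLine("foo\nbar", "  ") returns "  foo\n  bar".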

std::pair<size_t, MatchedSchema> matchSchemas(
    const std::vector<const FunctionSchema*>& schemas,
    const SourceRange& loc,
    Graph& graph,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const std::optional<NamedValue>& self,
    bool render_errors) {
  TORCH_INTERNAL_ASSERT(!schemas.empty());
  // if there is only one schema, we do not need to try without conversions
  // first. this is faster and puts less dead code in the graph.
  if (schemas.size() == 1) {
    return std::make_pair(
        0, matchSchema(*schemas.at(0), loc, graph, args, kwargs, self));
  }
  std::stringstream failure_messages;
  for (bool allow_conversions : {false, true}) {
    // clear previous error messages
    failure_messages.str("");
    for (const auto i : c10::irange(schemas.size())) {
      const auto matched_schema = tryMatchSchema(
          *schemas[i],
          loc,
          graph,
          args,
          kwargs,
          self,
          render_errors ? &failure_messages : nullptr,
          allow_conversions);
      if (matched_schema) {
        return std::make_pair(i, *matched_schema);
      }
    }
  }
  // we optimistically assume this call will not error, and avoid formatting
  // the error strings. If we discover it did error, then we replay it,
  // recording the errors.
  if (!render_errors) {
    return matchSchemas(
        schemas, loc, graph, args, kwargs, self, /*render_errors=*/true);
  }

  throw(
      ErrorReport(loc) << "Arguments for call are not valid.\n"
                       << "The following variants are available:\n"
                       << prefixLine(failure_messages.str(), "  ")
                       << "\nThe original call is");
}
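// For example, with overloads add(Tensor, Tensor) and add(Tensor, Scalar),
// the first (conversion-free) pass lets a (Tensor, Tensor) call pick the
// exact overload before any implicit Tensor-to-Scalar conversion is
// considered in the second pass.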

// Pack the outputs of a function following Python rules: if there is a single
// value, return it directly; otherwise pack all the values into a Tuple.
static Value* packOutputs(
    Graph& g,
    at::ArrayRef<Value*> values,
    c10::OptNameList field_names) {
  if (values.size() == 1) {
    return values[0];
  }
  TupleTypePtr named_tuple = nullptr;
  if (field_names) {
    auto types = fmap(values, [](Value* v) { return v->type(); });
    named_tuple =
        TupleType::createNamed(std::nullopt, field_names.value(), types);
  }
  return g.insertNode(g.createTuple(values, named_tuple))->output();
}
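// For example, aten::max.dim returns (values, indices); packOutputs wraps its
// two outputs in a named tuple carrying those field names.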

// Given a successful match between operator schema and symbol, emit a node
// with the appropriate inputs and outputs.
static Value* emitBuiltinNode(
    const MatchedSchema& matched_schema,
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    std::optional<size_t> version) {
  auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
               ->setSourceRange(loc);

  for (auto& ret : matched_schema.return_types) {
    n->addOutput()->setType(ret);
  }

  // Assert that we did indeed create an op that has an implementation;
  // otherwise schema and dispatch are not in sync. Only do this if the op is
  // up to date with the server version; otherwise record the historic schema
  // name instead.
  if (!version.has_value() ||
      isOpSymbolCurrent(matched_schema.schema_name, version.value())) {
    n->getOperation();
  } else {
    n->setHistoricSchemaName(matched_schema.schema_name);
  }

  return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
}

std::string getFullSchemaName(const ::c10::FunctionSchema& schema) {
  if (!schema.overload_name().empty()) {
    return schema.operator_name().name + "." + schema.overload_name();
  }
  return schema.operator_name().name;
}
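// For example, aten::add with overload name "Tensor" yields
// "aten::add.Tensor", while an op with no overload name yields just its
// qualified name.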

// Search for operators matching the provided symbol name and input types.
// If one is found, emit a node to the graph for that operator.
Value* emitBuiltinCall(
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const std::optional<NamedValue>& self) {
  const auto& variants = getAllOperatorsFor(name);
  const auto& builtin_functions = getAllBuiltinFunctionsFor(name);

  // first let's get the graph's operator version
  auto graph_version = graph.get_op_version();

  std::vector<const FunctionSchema*> schemas;
  // upgrader schemas are owned here and appended to `schemas` later, because
  // parseSchema returns an rvalue whose address cannot be taken directly
  std::vector<FunctionSchema> upgrader_schemas;
  schemas.reserve(variants.size());
  for (const std::shared_ptr<Operator>& op : variants) {
    bool found_upgrader = false;
    auto op_name = getFullSchemaName(op->schema());
    if (graph_version.has_value()) {
      auto version_entry = get_operator_version_map().find(op_name);
      if (version_entry != get_operator_version_map().end()) {
        auto old_schema_entry =
            findUpgrader(version_entry->second, graph_version.value());
        if (old_schema_entry.has_value()) {
          FunctionSchema old_schema =
              parseSchema(old_schema_entry.value().old_schema);
          upgrader_schemas.push_back(old_schema);
          found_upgrader = true;
        } else {
          if (!isOpCurrentBasedOnUpgraderEntries(
                  version_entry->second, graph_version.value())) {
            TORCH_INTERNAL_ASSERT(false, "Valid upgrader must be present");
          }
        }
      }
    }
    if (!found_upgrader)
      schemas.push_back(&op->schema());
  }

  // we might have seen old historic ops that are deprecated
  if (variants.empty()) {
    auto oldSchemas =
        loadPossibleHistoricOps(name.toQualString(), graph_version);
    upgrader_schemas.reserve(oldSchemas.size());
    for (const auto& old_schema_entry : oldSchemas) {
      FunctionSchema old_schema = parseSchema(old_schema_entry);
      upgrader_schemas.emplace_back(old_schema);
    }
  }

  // TODO (tugsuu): make sure this is optimized later
  for (const auto& schema : upgrader_schemas) {
    schemas.push_back(&schema);
  }

  for (const auto method : builtin_functions) {
    method->ensure_defined();
    schemas.push_back(&method->getSchema());
  }

  // no operators were found with the same name; print out similarly named
  // operators
  if (schemas.empty()) {
    const auto close_symbols = findSimilarOperators(name);
    auto error = ErrorReport(loc);
    const auto& user_function_name = name.toQualString();
    error << "Unknown builtin op: " << user_function_name << ".\n";
    if (close_symbols.empty()) {
      error
          << "Could not find any similar ops to " << user_function_name
          << ". This op may not exist or may not be currently supported in TorchScript.\n";
    } else {
      error << "Here are some suggestions: \n";
      for (const auto& sym : close_symbols) {
        error << "\t" << sym.toQualString() << "\n";
      }
      error << "\nThe original call is";
    }
    throw ErrorReport(error);
  }

  auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self);

  if (matched.first < variants.size() + upgrader_schemas.size()) {
    return emitBuiltinNode(matched.second, loc, graph, name, graph_version);
  } else {
    auto& fn = *builtin_functions[matched.first - variants.size()];
    // we inline builtin calls because they are normally very small
    // wrappers and are not useful to keep around for debugging
    return insertGraph(
               graph, *toGraphFunction(fn).graph(), matched.second.inputs)
        .at(0);
  }
}

} // namespace torch::jit