xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/function_schema.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/function_schema.h>
2 
3 #include <iostream>
4 #include <stack>
5 #include <utility>
6 
7 namespace c10 {
8 
dump() const9 void FunctionSchema::dump() const {
10   std::cout << *this << "\n";
11 }
12 
getCorrectList(SchemaArgType type) const13 const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) const {
14   if (type == SchemaArgType::input) {
15     return arguments();
16   } else {
17     return returns();
18   }
19 }
20 
cloneWithRealTypes(bool with_symint) const21 FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
22   auto alwaysCloneWithRealTypes = [&](const Argument& a) {
23     return a.cloneWithType(a.real_type());
24   };
25   auto cloneWithRealTypes = [&](const Argument& a) {
26     if (with_symint) {
27       return a.cloneWithType(a.real_type());
28     }
29     // Don't use real type if it looks like a SymInt
30     // NB: keep this in sync with unpackSymInt in KernelFunction_impl.h
31     if (
32       *a.real_type() == *getTypePtr<c10::SymInt>() ||
33       *a.real_type() == *getTypePtr<std::optional<c10::SymInt>>() ||
34       *a.real_type() == *getTypePtr<c10::SymIntArrayRef>() ||
35       *a.real_type() == *getTypePtr<at::OptionalSymIntArrayRef>()
36     ) {
37       // Keep the fake type
38       return a.cloneWithType(a.type());
39     } else {
40       return a.cloneWithType(a.real_type());
41     }
42   };
43   std::vector<Argument> new_arguments, new_returns;
44   std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
45   // NB: SymInt returns are always SymInt
46   std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), alwaysCloneWithRealTypes);
47   return FunctionSchema(
48     name(),
49     overload_name(),
50     std::move(new_arguments),
51     std::move(new_returns),
52     is_vararg(),
53     is_varret());
54 }
55 
canAliasTypeSetsAlias(const std::optional<AliasTypeSet> & lhs,const std::optional<AliasTypeSet> & rhs) const56 bool FunctionSchema::canAliasTypeSetsAlias(const std::optional<AliasTypeSet> &lhs, const std::optional<AliasTypeSet> &rhs) const {
57   if (!lhs || !rhs) {
58     return false;
59   }
60   for (const TypePtr& lhsType : *lhs) {
61     for (const TypePtr& rhsType : *rhs) {
62       if (lhsType == rhsType) {
63         return true;
64       }
65     }
66   }
67   return false;
68 }
69 
getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> & aliasTypeSet) const70 std::optional<AliasTypeSet> FunctionSchema::getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> &aliasTypeSet) const {
71   if (!aliasTypeSet) {
72     return std::nullopt;
73   }
74   std::unordered_set<TypePtr> containedTypes;
75   std::stack<TypePtr> typeStack;
76   // Push all 1st level contained types into the stack.
77   for (const TypePtr& type: *aliasTypeSet) {
78     for (const TypePtr& containedType : type->containedTypes()){
79       typeStack.push(containedType);
80     }
81   }
82 
83   // process all further level contained types.
84   while (!typeStack.empty()) {
85     TypePtr current = typeStack.top();
86     typeStack.pop();
87     if (!containedTypes.count(current)) {
88       for (const TypePtr& containedType : current->containedTypes()) {
89         typeStack.push(containedType);
90       }
91     }
92     containedTypes.insert(current);
93   }
94 
95   return AliasTypeSet(containedTypes.begin(), containedTypes.end());
96 }
97 
mapTypeToAliasTypeSet(const TypePtr & type) const98 std::optional<AliasTypeSet> FunctionSchema::mapTypeToAliasTypeSet(const TypePtr& type) const {
99   switch(type->kind()) {
100     case TypeKind::ListType:
101     case TypeKind::DictType:
102     case TypeKind::ClassType:
103     case TypeKind::TensorType:
104       return AliasTypeSet {c10::unshapedType(type)};
105     case TypeKind::UnionType: {
106       AliasTypeSet mutable_types;
107       for (const TypePtr& inner :
108             type->expectRef<UnionType>().containedTypes()) {
109         if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
110           mutable_types.insert(
111               mutable_types.end(),
112               (*maybe_inner_types).begin(),
113               (*maybe_inner_types).end());
114         }
115       }
116       if (mutable_types.empty()) {
117         return std::nullopt;
118       }
119       return mutable_types;
120     }
121     case TypeKind::AnyType:
122       return {AliasTypeSet{type}};
123     case TypeKind::OptionalType: {
124       auto inner = type->castRaw<OptionalType>()->getElementType();
125       return mapTypeToAliasTypeSet(inner);
126     }
127     case TypeKind::TupleType: {
128       AliasTypeSet mutable_types;
129       for (const TypePtr& inner : type->expectRef<TupleType>().elements()) {
130         if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
131           mutable_types.insert(
132               mutable_types.end(),
133               (*maybe_inner_types).begin(),
134               (*maybe_inner_types).end());
135         }
136       }
137       if (mutable_types.empty()) {
138         return std::nullopt;
139       }
140       return {AliasTypeSet{TupleType::create(std::move(mutable_types))}};
141     }
142     default:
143       return std::nullopt;
144   }
145 }
146 
may_alias(const SchemaArgument & lhs,const SchemaArgument & rhs) const147 bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const {
148   TORCH_INTERNAL_ASSERT(
149       (lhs.index < getCorrectList(lhs.type).size()),
150       "Invalid index for schema.");
151   TORCH_INTERNAL_ASSERT(
152       (rhs.index < getCorrectList(rhs.type).size()),
153       "Invalid index for schema.");
154 
155   const Argument lhsArg = getCorrectList(lhs.type)[lhs.index];
156   const Argument rhsArg = getCorrectList(rhs.type)[rhs.index];
157 
158   std::optional<AliasTypeSet> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
159   std::optional<AliasTypeSet> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
160 
161   // Check to see if lhs and rhs have the same alias set
162   if (canAliasTypeSetsAlias(lhsTypes, rhsTypes)) {
163     if (lhsArg.alias_info() && rhsArg.alias_info()) {
164       for (const auto& lhsSet : lhsArg.alias_info()->afterSets()) {
165         for (const auto& rhsSet : rhsArg.alias_info()->afterSets()) {
166           if (lhsSet == rhsSet) {
167             return true;
168           }
169         }
170       }
171     }
172   }
173 
174   return false;
175 }
176 
may_contain_alias(const SchemaArgument & lhs,const SchemaArgument & rhs,bool bidirectional) const177 bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional) const {
178   bool may_alias_result = may_alias(lhs, rhs);
179   if (may_alias_result) {
180     return true;
181   }
182 
183   const c10::Argument lhsArg = getCorrectList(lhs.type)[lhs.index];
184   const c10::Argument rhsArg = getCorrectList(rhs.type)[rhs.index];
185   std::optional<AliasTypeSet> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
186   std::optional<AliasTypeSet> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
187   std::optional<AliasTypeSet> lhsContainedTypes = getAliasTypeSetContainedTypes(lhsTypes);
188   std::optional<AliasTypeSet> rhsContainedTypes = getAliasTypeSetContainedTypes(rhsTypes);
189 
190   // Checks if one side is wildcard and the other side is a container of the same type
191   bool lhsWildcard = lhsArg.alias_info() && lhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(lhsTypes, rhsContainedTypes);
192   bool rhsWildcard = rhsArg.alias_info() && rhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(rhsTypes, lhsContainedTypes);
193 
194   if (bidirectional) {
195     return lhsWildcard || rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
196   } else {
197     return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
198   }
199 }
200 
operator <<(std::ostream & out,const FunctionSchema & schema)201 std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
202   // eventually this should look almost identical to python arg parser, but
203   // it is simpler for now to work directly on this schema
204 
205   out << schema.name();
206   if (!schema.overload_name().empty()) {
207     out << "." << schema.overload_name();
208   }
209   out << "(";
210 
211   bool seen_kwarg_only = false;
212   for (const auto i : c10::irange(schema.arguments().size())) {
213     if (i > 0) out << ", ";
214     if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
215       out << "*, ";
216       seen_kwarg_only = true;
217     }
218     out << schema.arguments()[i];
219   }
220 
221   if(schema.is_vararg()) {
222     if(!schema.arguments().empty())
223       out << ", ";
224     out << "...";
225   }
226 
227   out << ") -> ";
228 
229   const auto& returns = schema.returns();
230 
231   /*
232    * We should skip parenthesis if we return a single item and it's not varret,
233    * or we return nothing but varret.
234    *
235    * Need special handling for schema
236    *   aten::items.str(Dict(str, t) self) -> (str,t)[]
237    * Even though this schema returns a single item, we need add parenthesis.
238    * The is necessary so the printed schema can be parsed by the C++ SchemaParser
239    * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly
240    * treat the return type as a tuple. An alternative is to enhance the Lexer
241    * to lookahead multiple tokens to accurately decide if the return type is
242    * a tuple.
243    */
244   bool need_paren = !(
245     (returns.size() == 1 && !schema.is_varret()) ||
246     (returns.empty() && schema.is_varret()));
247 
248   if (returns.size() == 1 && !schema.is_varret()) {
249     std::stringstream return_ss;
250     return_ss << returns.at(0);
251     auto return_str = return_ss.str();
252 
253     // enclosing the single return item with parenthesis if the return type
254     // starts with a left parenthesis.
255     //
256     // There are 2 cases
257     // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
258     // without the extra parenthesis, the c++ schem parser can not parse it.
259     // 2. something like '-> ((str, str))'. Need extra parenthesis so the return
260     // type is a single tuple rather than two strings.
261     // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
262     // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15)
263     // also covers this case.
264     if (!return_str.empty() && return_str.front() == '(') {
265       need_paren = true;
266     }
267   }
268 
269   if (need_paren) {
270     out << "(";
271   }
272   for (const auto i : c10::irange(returns.size())) {
273     if (i > 0) {
274       out << ", ";
275     }
276     out << returns.at(i);
277   }
278   if (schema.is_varret()) {
279     if (!returns.empty()) {
280       out << ", ";
281     }
282     out << "...";
283   }
284   if (need_paren) {
285     out << ")";
286   }
287   return out;
288 }
289 
findFirstOutArg(const std::vector<Argument> & args)290 static size_t findFirstOutArg(const std::vector<Argument>& args) {
291   // find the start of out args in the schema
292   for (const auto out_start_idx : c10::irange(args.size())) {
293     if (args.at(out_start_idx).is_out()) {
294       return out_start_idx;
295     }
296   }
297   return args.size();
298 }
299 
isBackwardCompatibleWith(const Argument & old,std::ostream * why_not) const300 bool Argument::isBackwardCompatibleWith(
301       const Argument& old,
302       std::ostream* why_not) const {
303     const Argument* lhs = this;
304     const Argument* rhs = &old;
305     if (!(lhs->name() == rhs->name()
306         && lhs->N() == rhs->N()
307           && (lhs->alias_info() == rhs->alias_info()
308               || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
309                   && *lhs->alias_info() == *rhs->alias_info())))) {
310       return false;
311     }
312     if (lhs->kwarg_only() && !rhs->kwarg_only()) {
313       return false;
314     }
315     if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) {
316       return false;
317     }
318     if (rhs->default_value().has_value() &&
319         lhs->default_value() != rhs->default_value()) {
320       return false;
321     }
322     return true;
323 }
324 
isForwardCompatibleWith(const Argument & old,std::ostream * why_not) const325 bool Argument::isForwardCompatibleWith(
326     const Argument& old,
327     std::ostream* why_not) const {
328   const Argument* lhs = this;
329   const Argument* rhs = &old;
330   if (!(lhs->name() == rhs->name()
331       && lhs->N() == rhs->N()
332         && (lhs->alias_info() == rhs->alias_info()
333             || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
334                 && *lhs->alias_info() == *rhs->alias_info())))) {
335     return false;
336   }
337   if (lhs->kwarg_only() && !rhs->kwarg_only()) {
338     return false;
339   }
340   if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) {
341     return false;
342   }
343   if (rhs->default_value().has_value() &&
344       lhs->default_value() != rhs->default_value()) {
345     return false;
346   }
347   if (lhs->default_value().has_value() && !rhs->default_value().has_value()) {
348     return false;
349   }
350   return true;
351 }
352 
formatTypeMismatchMsg(const Argument & expected,const std::string & actual_type,std::optional<size_t> position,std::optional<std::string> value) const353 std::string FunctionSchema::formatTypeMismatchMsg(
354     const Argument& expected,
355     const std::string& actual_type,
356     std::optional<size_t> position,
357     std::optional<std::string> value) const {
358   std::string position_str;
359   if (position) {
360     position_str = c10::str("Position: ", *position, "\n");
361   }
362   std::string value_str;
363   if (value) {
364     value_str = c10::str("Value: ", *value, "\n");
365   }
366   return c10::str(
367       name(),
368       "() ",
369       expected.formatTypeMismatchMsg(actual_type),
370       position_str,
371       value_str,
372       "Declaration: ",
373       *this);
374 }
375 
isBackwardCompatibleWith(const FunctionSchema & old,std::ostream * why_not) const376 bool FunctionSchema::isBackwardCompatibleWith(
377     const FunctionSchema& old,
378     std::ostream* why_not) const {
379   if (!(name() == old.name()
380         && overload_name() == old.overload_name()
381         // we are conservative on is_vararg and is_varret,
382         // since they are only used by internal operators
383         && is_vararg() == old.is_vararg()
384         && is_varret() == old.is_varret()
385         && returns().size() == old.returns().size()
386         && arguments().size() >= old.arguments().size())) {
387     return false;
388   }
389   for (const auto i : c10::irange(returns().size())) {
390     // Backwards compatibility requires covariance on argument types
391     // (i.e. more generic), and contravariance on return types (i.e.
392     //  more specific).
393     if (!old.returns().at(i).isBackwardCompatibleWith(
394           returns().at(i),
395           why_not)) {
396       return false;
397     }
398   }
399 
400   // we want to test both out and default args separately
401   size_t old_out_start_idx = findFirstOutArg(old.arguments());
402   size_t new_out_start_idx = findFirstOutArg(arguments());
403 
404   // make sure among the default args, they are backward compatible
405   for (const auto i : c10::irange(old_out_start_idx)) {
406     if (!arguments().at(i).isBackwardCompatibleWith(
407           old.arguments().at(i), why_not)) {
408       return false;
409     }
410   }
411 
412   // Validate that all new arguments provided has a default value
413   for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) {
414     if (!arguments().at(i).default_value()) {
415       if (why_not) {
416         *why_not
417             << "Function schema not backward compatible since the new argument '"
418             << arguments().at(i).name() << "' of type "
419             << arguments().at(i).type()->str()
420             << " did not provide a default value.";
421       }
422       return false;
423     }
424   }
425 
426   // now compare the out args
427   for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) {
428     if (!arguments()
429              .at(i - old_out_start_idx + new_out_start_idx)
430              .isBackwardCompatibleWith(old.arguments().at(i), why_not)) {
431       return false;
432     }
433   }
434 
435   return true;
436 }
437 
isForwardCompatibleWith(const FunctionSchema & old,std::ostringstream & why_not) const438 bool FunctionSchema::isForwardCompatibleWith(
439     const FunctionSchema& old,
440     std::ostringstream& why_not) const {
441   if (!(name() == old.name() &&
442         overload_name() == old.overload_name()
443         // we are conservative on is_vararg and is_varret,
444         // since they are only used by internal operators
445         && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() &&
446         returns().size() == old.returns().size())) {
447     return false;
448   }
449 
450   // we want to test both out and default args separately
451   size_t old_out_start_idx = findFirstOutArg(old.arguments());
452   size_t new_out_start_idx = findFirstOutArg(arguments());
453 
454   if (old.arguments().size() - old_out_start_idx !=
455       arguments().size() - new_out_start_idx) {
456     if (why_not) {
457       why_not << "Function schema should have the "
458               << "same number of out arguments";
459     }
460     return false;
461   }
462 
463   // make sure among the default args, they are forward compatible
464   for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) {
465     if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
466       if (why_not) {
467         why_not
468             << "'" << arguments().at(i).name() << "'"
469             << " is not forward compatible with the older version of the schema";
470       }
471       return false;
472     }
473   }
474 
475   // Validate that all new arguments provided has a default value
476   for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
477     if (!arguments().at(i).default_value()) {
478       if (why_not) {
479         why_not
480             << "Function schema is not forward compatible since the new argument '"
481             << arguments().at(i).name() << "' of type "
482             << arguments().at(i).type()->str()
483             << " did not provide a default value.";
484       }
485       return false;
486     }
487 
488     auto default_val = arguments().at(i).default_value().value();
489     if (default_val.isList() || default_val.isGenericDict()) {
490       if (why_not) {
491         why_not
492             << "Function schema is not forward compatible since the new argument '"
493             << arguments().at(i).name() << "' of type "
494             << arguments().at(i).type()->str() << " has a container type "
495             << "as its default value.";
496       }
497       return false;
498     }
499   }
500 
501   // now compare the out args
502   for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
503     if (!arguments()
504              .at(i - old_out_start_idx + new_out_start_idx)
505              .isForwardCompatibleWith(old.arguments().at(i))) {
506       if (why_not) {
507         why_not << "Out argument '"
508                 << "'" << arguments().at(i).name()
509                 << " is not FC with the older version of the schema";
510       }
511       return false;
512     }
513   }
514 
515   return true;
516 }
517 
findErrorInKwargs(const std::vector<std::string> & kwargs) const518 std::string FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const {
519   // First check if any of the kwargs are unknown, i.e. don't match the name of
520   // any argument in the schema.
521   for (const auto& kwarg : kwargs) {
522     if (!std::count_if(
523             arguments().begin(),
524             arguments().end(),
525             [&kwarg](const Argument& argument) {
526               return argument.name() == kwarg;
527             })) {
528       return c10::str(
529           "Unknown keyword argument '",
530           kwarg,
531           "' for operator '",
532           name(),
533           "'. Schema: ",
534           *this);
535     }
536   }
537   // If there are unconsumed kwargs but none of them were unknown, the first
538   // positional argument present in the kwargs is duplicated.
539   for (const auto& argument : arguments()) {
540     if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) {
541       AT_ASSERT(!argument.default_value());
542       return c10::str(
543           "Argument '",
544           argument.name(),
545           "' specified both as positional and ",
546           "keyword argument. Schema: ",
547           *this);
548     }
549   }
550   return "";
551 }
552 
553 
cloneWithRemappedTypes(const std::function<TypePtr (TypePtr)> type_map) const554 FunctionSchema FunctionSchema::cloneWithRemappedTypes(
555     const std::function<TypePtr(TypePtr)> type_map) const {
556   auto update_args = [&](const std::vector<Argument>& args) {
557     std::vector<Argument> new_args;
558     new_args.reserve(args.size());
559     for(const Argument& arg : args) {
560       new_args.emplace_back(arg.cloneWithType(type_map(arg.type())));
561     }
562     return new_args;
563   };
564   return FunctionSchema(
565       name(),
566       overload_name(),
567       update_args(arguments()),
568       update_args(returns()),
569       is_vararg(),
570       is_varret());
571 }
572 
573 // covariant subtyping of list of Arguments
isSubtypeOfList(ArrayRef<Argument> child,ArrayRef<Argument> parent,std::ostream * why_not)574 static bool isSubtypeOfList(
575     ArrayRef<Argument> child,
576     ArrayRef<Argument> parent,
577     std::ostream* why_not) {
578   if (child.size() != parent.size()) {
579     return false;
580   }
581   for (const auto i : c10::irange(child.size())) {
582     const Argument& c = child[i];
583     const Argument& p = parent[i];
584     if (c.name() != p.name()) {
585       return false;
586     }
587     if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) {
588       return false;
589     }
590   }
591   return true;
592 }
593 
isSubtypeOf(const FunctionSchema & rhs,bool as_method,std::ostream * why_not) const594 bool FunctionSchema::isSubtypeOf(
595     const FunctionSchema& rhs,
596     bool as_method,
597     std::ostream* why_not) const {
598   size_t start = as_method ? 1 : 0;
599   // functions are contravariant in arguments but covariant in returns
600   return isSubtypeOfList(
601              ArrayRef<Argument>(rhs.arguments()).slice(start),
602              ArrayRef<Argument>(arguments()).slice(start),
603              why_not) &&
604       isSubtypeOfList(returns(), rhs.returns(), why_not);
605 }
606 
607 } // namespace c10
608