1 #include <ATen/core/function_schema.h>
2
3 #include <iostream>
4 #include <stack>
5 #include <utility>
6
7 namespace c10 {
8
dump() const9 void FunctionSchema::dump() const {
10 std::cout << *this << "\n";
11 }
12
getCorrectList(SchemaArgType type) const13 const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) const {
14 if (type == SchemaArgType::input) {
15 return arguments();
16 } else {
17 return returns();
18 }
19 }
20
cloneWithRealTypes(bool with_symint) const21 FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
22 auto alwaysCloneWithRealTypes = [&](const Argument& a) {
23 return a.cloneWithType(a.real_type());
24 };
25 auto cloneWithRealTypes = [&](const Argument& a) {
26 if (with_symint) {
27 return a.cloneWithType(a.real_type());
28 }
29 // Don't use real type if it looks like a SymInt
30 // NB: keep this in sync with unpackSymInt in KernelFunction_impl.h
31 if (
32 *a.real_type() == *getTypePtr<c10::SymInt>() ||
33 *a.real_type() == *getTypePtr<std::optional<c10::SymInt>>() ||
34 *a.real_type() == *getTypePtr<c10::SymIntArrayRef>() ||
35 *a.real_type() == *getTypePtr<at::OptionalSymIntArrayRef>()
36 ) {
37 // Keep the fake type
38 return a.cloneWithType(a.type());
39 } else {
40 return a.cloneWithType(a.real_type());
41 }
42 };
43 std::vector<Argument> new_arguments, new_returns;
44 std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
45 // NB: SymInt returns are always SymInt
46 std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), alwaysCloneWithRealTypes);
47 return FunctionSchema(
48 name(),
49 overload_name(),
50 std::move(new_arguments),
51 std::move(new_returns),
52 is_vararg(),
53 is_varret());
54 }
55
canAliasTypeSetsAlias(const std::optional<AliasTypeSet> & lhs,const std::optional<AliasTypeSet> & rhs) const56 bool FunctionSchema::canAliasTypeSetsAlias(const std::optional<AliasTypeSet> &lhs, const std::optional<AliasTypeSet> &rhs) const {
57 if (!lhs || !rhs) {
58 return false;
59 }
60 for (const TypePtr& lhsType : *lhs) {
61 for (const TypePtr& rhsType : *rhs) {
62 if (lhsType == rhsType) {
63 return true;
64 }
65 }
66 }
67 return false;
68 }
69
getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> & aliasTypeSet) const70 std::optional<AliasTypeSet> FunctionSchema::getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> &aliasTypeSet) const {
71 if (!aliasTypeSet) {
72 return std::nullopt;
73 }
74 std::unordered_set<TypePtr> containedTypes;
75 std::stack<TypePtr> typeStack;
76 // Push all 1st level contained types into the stack.
77 for (const TypePtr& type: *aliasTypeSet) {
78 for (const TypePtr& containedType : type->containedTypes()){
79 typeStack.push(containedType);
80 }
81 }
82
83 // process all further level contained types.
84 while (!typeStack.empty()) {
85 TypePtr current = typeStack.top();
86 typeStack.pop();
87 if (!containedTypes.count(current)) {
88 for (const TypePtr& containedType : current->containedTypes()) {
89 typeStack.push(containedType);
90 }
91 }
92 containedTypes.insert(current);
93 }
94
95 return AliasTypeSet(containedTypes.begin(), containedTypes.end());
96 }
97
mapTypeToAliasTypeSet(const TypePtr & type) const98 std::optional<AliasTypeSet> FunctionSchema::mapTypeToAliasTypeSet(const TypePtr& type) const {
99 switch(type->kind()) {
100 case TypeKind::ListType:
101 case TypeKind::DictType:
102 case TypeKind::ClassType:
103 case TypeKind::TensorType:
104 return AliasTypeSet {c10::unshapedType(type)};
105 case TypeKind::UnionType: {
106 AliasTypeSet mutable_types;
107 for (const TypePtr& inner :
108 type->expectRef<UnionType>().containedTypes()) {
109 if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
110 mutable_types.insert(
111 mutable_types.end(),
112 (*maybe_inner_types).begin(),
113 (*maybe_inner_types).end());
114 }
115 }
116 if (mutable_types.empty()) {
117 return std::nullopt;
118 }
119 return mutable_types;
120 }
121 case TypeKind::AnyType:
122 return {AliasTypeSet{type}};
123 case TypeKind::OptionalType: {
124 auto inner = type->castRaw<OptionalType>()->getElementType();
125 return mapTypeToAliasTypeSet(inner);
126 }
127 case TypeKind::TupleType: {
128 AliasTypeSet mutable_types;
129 for (const TypePtr& inner : type->expectRef<TupleType>().elements()) {
130 if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
131 mutable_types.insert(
132 mutable_types.end(),
133 (*maybe_inner_types).begin(),
134 (*maybe_inner_types).end());
135 }
136 }
137 if (mutable_types.empty()) {
138 return std::nullopt;
139 }
140 return {AliasTypeSet{TupleType::create(std::move(mutable_types))}};
141 }
142 default:
143 return std::nullopt;
144 }
145 }
146
may_alias(const SchemaArgument & lhs,const SchemaArgument & rhs) const147 bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const {
148 TORCH_INTERNAL_ASSERT(
149 (lhs.index < getCorrectList(lhs.type).size()),
150 "Invalid index for schema.");
151 TORCH_INTERNAL_ASSERT(
152 (rhs.index < getCorrectList(rhs.type).size()),
153 "Invalid index for schema.");
154
155 const Argument lhsArg = getCorrectList(lhs.type)[lhs.index];
156 const Argument rhsArg = getCorrectList(rhs.type)[rhs.index];
157
158 std::optional<AliasTypeSet> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
159 std::optional<AliasTypeSet> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
160
161 // Check to see if lhs and rhs have the same alias set
162 if (canAliasTypeSetsAlias(lhsTypes, rhsTypes)) {
163 if (lhsArg.alias_info() && rhsArg.alias_info()) {
164 for (const auto& lhsSet : lhsArg.alias_info()->afterSets()) {
165 for (const auto& rhsSet : rhsArg.alias_info()->afterSets()) {
166 if (lhsSet == rhsSet) {
167 return true;
168 }
169 }
170 }
171 }
172 }
173
174 return false;
175 }
176
may_contain_alias(const SchemaArgument & lhs,const SchemaArgument & rhs,bool bidirectional) const177 bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional) const {
178 bool may_alias_result = may_alias(lhs, rhs);
179 if (may_alias_result) {
180 return true;
181 }
182
183 const c10::Argument lhsArg = getCorrectList(lhs.type)[lhs.index];
184 const c10::Argument rhsArg = getCorrectList(rhs.type)[rhs.index];
185 std::optional<AliasTypeSet> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
186 std::optional<AliasTypeSet> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
187 std::optional<AliasTypeSet> lhsContainedTypes = getAliasTypeSetContainedTypes(lhsTypes);
188 std::optional<AliasTypeSet> rhsContainedTypes = getAliasTypeSetContainedTypes(rhsTypes);
189
190 // Checks if one side is wildcard and the other side is a container of the same type
191 bool lhsWildcard = lhsArg.alias_info() && lhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(lhsTypes, rhsContainedTypes);
192 bool rhsWildcard = rhsArg.alias_info() && rhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(rhsTypes, lhsContainedTypes);
193
194 if (bidirectional) {
195 return lhsWildcard || rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
196 } else {
197 return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
198 }
199 }
200
operator <<(std::ostream & out,const FunctionSchema & schema)201 std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
202 // eventually this should look almost identical to python arg parser, but
203 // it is simpler for now to work directly on this schema
204
205 out << schema.name();
206 if (!schema.overload_name().empty()) {
207 out << "." << schema.overload_name();
208 }
209 out << "(";
210
211 bool seen_kwarg_only = false;
212 for (const auto i : c10::irange(schema.arguments().size())) {
213 if (i > 0) out << ", ";
214 if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
215 out << "*, ";
216 seen_kwarg_only = true;
217 }
218 out << schema.arguments()[i];
219 }
220
221 if(schema.is_vararg()) {
222 if(!schema.arguments().empty())
223 out << ", ";
224 out << "...";
225 }
226
227 out << ") -> ";
228
229 const auto& returns = schema.returns();
230
231 /*
232 * We should skip parenthesis if we return a single item and it's not varret,
233 * or we return nothing but varret.
234 *
235 * Need special handling for schema
236 * aten::items.str(Dict(str, t) self) -> (str,t)[]
237 * Even though this schema returns a single item, we need add parenthesis.
238 * The is necessary so the printed schema can be parsed by the C++ SchemaParser
239 * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly
240 * treat the return type as a tuple. An alternative is to enhance the Lexer
241 * to lookahead multiple tokens to accurately decide if the return type is
242 * a tuple.
243 */
244 bool need_paren = !(
245 (returns.size() == 1 && !schema.is_varret()) ||
246 (returns.empty() && schema.is_varret()));
247
248 if (returns.size() == 1 && !schema.is_varret()) {
249 std::stringstream return_ss;
250 return_ss << returns.at(0);
251 auto return_str = return_ss.str();
252
253 // enclosing the single return item with parenthesis if the return type
254 // starts with a left parenthesis.
255 //
256 // There are 2 cases
257 // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
258 // without the extra parenthesis, the c++ schem parser can not parse it.
259 // 2. something like '-> ((str, str))'. Need extra parenthesis so the return
260 // type is a single tuple rather than two strings.
261 // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
262 // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15)
263 // also covers this case.
264 if (!return_str.empty() && return_str.front() == '(') {
265 need_paren = true;
266 }
267 }
268
269 if (need_paren) {
270 out << "(";
271 }
272 for (const auto i : c10::irange(returns.size())) {
273 if (i > 0) {
274 out << ", ";
275 }
276 out << returns.at(i);
277 }
278 if (schema.is_varret()) {
279 if (!returns.empty()) {
280 out << ", ";
281 }
282 out << "...";
283 }
284 if (need_paren) {
285 out << ")";
286 }
287 return out;
288 }
289
findFirstOutArg(const std::vector<Argument> & args)290 static size_t findFirstOutArg(const std::vector<Argument>& args) {
291 // find the start of out args in the schema
292 for (const auto out_start_idx : c10::irange(args.size())) {
293 if (args.at(out_start_idx).is_out()) {
294 return out_start_idx;
295 }
296 }
297 return args.size();
298 }
299
isBackwardCompatibleWith(const Argument & old,std::ostream * why_not) const300 bool Argument::isBackwardCompatibleWith(
301 const Argument& old,
302 std::ostream* why_not) const {
303 const Argument* lhs = this;
304 const Argument* rhs = &old;
305 if (!(lhs->name() == rhs->name()
306 && lhs->N() == rhs->N()
307 && (lhs->alias_info() == rhs->alias_info()
308 || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
309 && *lhs->alias_info() == *rhs->alias_info())))) {
310 return false;
311 }
312 if (lhs->kwarg_only() && !rhs->kwarg_only()) {
313 return false;
314 }
315 if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) {
316 return false;
317 }
318 if (rhs->default_value().has_value() &&
319 lhs->default_value() != rhs->default_value()) {
320 return false;
321 }
322 return true;
323 }
324
isForwardCompatibleWith(const Argument & old,std::ostream * why_not) const325 bool Argument::isForwardCompatibleWith(
326 const Argument& old,
327 std::ostream* why_not) const {
328 const Argument* lhs = this;
329 const Argument* rhs = &old;
330 if (!(lhs->name() == rhs->name()
331 && lhs->N() == rhs->N()
332 && (lhs->alias_info() == rhs->alias_info()
333 || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
334 && *lhs->alias_info() == *rhs->alias_info())))) {
335 return false;
336 }
337 if (lhs->kwarg_only() && !rhs->kwarg_only()) {
338 return false;
339 }
340 if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) {
341 return false;
342 }
343 if (rhs->default_value().has_value() &&
344 lhs->default_value() != rhs->default_value()) {
345 return false;
346 }
347 if (lhs->default_value().has_value() && !rhs->default_value().has_value()) {
348 return false;
349 }
350 return true;
351 }
352
formatTypeMismatchMsg(const Argument & expected,const std::string & actual_type,std::optional<size_t> position,std::optional<std::string> value) const353 std::string FunctionSchema::formatTypeMismatchMsg(
354 const Argument& expected,
355 const std::string& actual_type,
356 std::optional<size_t> position,
357 std::optional<std::string> value) const {
358 std::string position_str;
359 if (position) {
360 position_str = c10::str("Position: ", *position, "\n");
361 }
362 std::string value_str;
363 if (value) {
364 value_str = c10::str("Value: ", *value, "\n");
365 }
366 return c10::str(
367 name(),
368 "() ",
369 expected.formatTypeMismatchMsg(actual_type),
370 position_str,
371 value_str,
372 "Declaration: ",
373 *this);
374 }
375
isBackwardCompatibleWith(const FunctionSchema & old,std::ostream * why_not) const376 bool FunctionSchema::isBackwardCompatibleWith(
377 const FunctionSchema& old,
378 std::ostream* why_not) const {
379 if (!(name() == old.name()
380 && overload_name() == old.overload_name()
381 // we are conservative on is_vararg and is_varret,
382 // since they are only used by internal operators
383 && is_vararg() == old.is_vararg()
384 && is_varret() == old.is_varret()
385 && returns().size() == old.returns().size()
386 && arguments().size() >= old.arguments().size())) {
387 return false;
388 }
389 for (const auto i : c10::irange(returns().size())) {
390 // Backwards compatibility requires covariance on argument types
391 // (i.e. more generic), and contravariance on return types (i.e.
392 // more specific).
393 if (!old.returns().at(i).isBackwardCompatibleWith(
394 returns().at(i),
395 why_not)) {
396 return false;
397 }
398 }
399
400 // we want to test both out and default args separately
401 size_t old_out_start_idx = findFirstOutArg(old.arguments());
402 size_t new_out_start_idx = findFirstOutArg(arguments());
403
404 // make sure among the default args, they are backward compatible
405 for (const auto i : c10::irange(old_out_start_idx)) {
406 if (!arguments().at(i).isBackwardCompatibleWith(
407 old.arguments().at(i), why_not)) {
408 return false;
409 }
410 }
411
412 // Validate that all new arguments provided has a default value
413 for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) {
414 if (!arguments().at(i).default_value()) {
415 if (why_not) {
416 *why_not
417 << "Function schema not backward compatible since the new argument '"
418 << arguments().at(i).name() << "' of type "
419 << arguments().at(i).type()->str()
420 << " did not provide a default value.";
421 }
422 return false;
423 }
424 }
425
426 // now compare the out args
427 for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) {
428 if (!arguments()
429 .at(i - old_out_start_idx + new_out_start_idx)
430 .isBackwardCompatibleWith(old.arguments().at(i), why_not)) {
431 return false;
432 }
433 }
434
435 return true;
436 }
437
isForwardCompatibleWith(const FunctionSchema & old,std::ostringstream & why_not) const438 bool FunctionSchema::isForwardCompatibleWith(
439 const FunctionSchema& old,
440 std::ostringstream& why_not) const {
441 if (!(name() == old.name() &&
442 overload_name() == old.overload_name()
443 // we are conservative on is_vararg and is_varret,
444 // since they are only used by internal operators
445 && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() &&
446 returns().size() == old.returns().size())) {
447 return false;
448 }
449
450 // we want to test both out and default args separately
451 size_t old_out_start_idx = findFirstOutArg(old.arguments());
452 size_t new_out_start_idx = findFirstOutArg(arguments());
453
454 if (old.arguments().size() - old_out_start_idx !=
455 arguments().size() - new_out_start_idx) {
456 if (why_not) {
457 why_not << "Function schema should have the "
458 << "same number of out arguments";
459 }
460 return false;
461 }
462
463 // make sure among the default args, they are forward compatible
464 for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) {
465 if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
466 if (why_not) {
467 why_not
468 << "'" << arguments().at(i).name() << "'"
469 << " is not forward compatible with the older version of the schema";
470 }
471 return false;
472 }
473 }
474
475 // Validate that all new arguments provided has a default value
476 for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
477 if (!arguments().at(i).default_value()) {
478 if (why_not) {
479 why_not
480 << "Function schema is not forward compatible since the new argument '"
481 << arguments().at(i).name() << "' of type "
482 << arguments().at(i).type()->str()
483 << " did not provide a default value.";
484 }
485 return false;
486 }
487
488 auto default_val = arguments().at(i).default_value().value();
489 if (default_val.isList() || default_val.isGenericDict()) {
490 if (why_not) {
491 why_not
492 << "Function schema is not forward compatible since the new argument '"
493 << arguments().at(i).name() << "' of type "
494 << arguments().at(i).type()->str() << " has a container type "
495 << "as its default value.";
496 }
497 return false;
498 }
499 }
500
501 // now compare the out args
502 for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
503 if (!arguments()
504 .at(i - old_out_start_idx + new_out_start_idx)
505 .isForwardCompatibleWith(old.arguments().at(i))) {
506 if (why_not) {
507 why_not << "Out argument '"
508 << "'" << arguments().at(i).name()
509 << " is not FC with the older version of the schema";
510 }
511 return false;
512 }
513 }
514
515 return true;
516 }
517
findErrorInKwargs(const std::vector<std::string> & kwargs) const518 std::string FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const {
519 // First check if any of the kwargs are unknown, i.e. don't match the name of
520 // any argument in the schema.
521 for (const auto& kwarg : kwargs) {
522 if (!std::count_if(
523 arguments().begin(),
524 arguments().end(),
525 [&kwarg](const Argument& argument) {
526 return argument.name() == kwarg;
527 })) {
528 return c10::str(
529 "Unknown keyword argument '",
530 kwarg,
531 "' for operator '",
532 name(),
533 "'. Schema: ",
534 *this);
535 }
536 }
537 // If there are unconsumed kwargs but none of them were unknown, the first
538 // positional argument present in the kwargs is duplicated.
539 for (const auto& argument : arguments()) {
540 if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) {
541 AT_ASSERT(!argument.default_value());
542 return c10::str(
543 "Argument '",
544 argument.name(),
545 "' specified both as positional and ",
546 "keyword argument. Schema: ",
547 *this);
548 }
549 }
550 return "";
551 }
552
553
cloneWithRemappedTypes(const std::function<TypePtr (TypePtr)> type_map) const554 FunctionSchema FunctionSchema::cloneWithRemappedTypes(
555 const std::function<TypePtr(TypePtr)> type_map) const {
556 auto update_args = [&](const std::vector<Argument>& args) {
557 std::vector<Argument> new_args;
558 new_args.reserve(args.size());
559 for(const Argument& arg : args) {
560 new_args.emplace_back(arg.cloneWithType(type_map(arg.type())));
561 }
562 return new_args;
563 };
564 return FunctionSchema(
565 name(),
566 overload_name(),
567 update_args(arguments()),
568 update_args(returns()),
569 is_vararg(),
570 is_varret());
571 }
572
573 // covariant subtyping of list of Arguments
isSubtypeOfList(ArrayRef<Argument> child,ArrayRef<Argument> parent,std::ostream * why_not)574 static bool isSubtypeOfList(
575 ArrayRef<Argument> child,
576 ArrayRef<Argument> parent,
577 std::ostream* why_not) {
578 if (child.size() != parent.size()) {
579 return false;
580 }
581 for (const auto i : c10::irange(child.size())) {
582 const Argument& c = child[i];
583 const Argument& p = parent[i];
584 if (c.name() != p.name()) {
585 return false;
586 }
587 if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) {
588 return false;
589 }
590 }
591 return true;
592 }
593
isSubtypeOf(const FunctionSchema & rhs,bool as_method,std::ostream * why_not) const594 bool FunctionSchema::isSubtypeOf(
595 const FunctionSchema& rhs,
596 bool as_method,
597 std::ostream* why_not) const {
598 size_t start = as_method ? 1 : 0;
599 // functions are contravariant in arguments but covariant in returns
600 return isSubtypeOfList(
601 ArrayRef<Argument>(rhs.arguments()).slice(start),
602 ArrayRef<Argument>(arguments()).slice(start),
603 why_not) &&
604 isSubtypeOfList(returns(), rhs.returns(), why_not);
605 }
606
607 } // namespace c10
608