1 #pragma once
2
3 #include <c10/util/StringUtil.h>
4 #include <c10/util/string_view.h>
5 #include <c10/util/irange.h>
6 #include <ATen/core/jit_type.h>
7 #include <ATen/core/symbol.h>
8 #include <ATen/core/ivalue.h>
9 #include <ATen/core/alias_info.h>
10 #include <ATen/core/operator_name.h>
11 #include <ATen/core/dispatch/OperatorOptions.h>
12 #include <unordered_map>
13 #include <utility>
14
15 namespace c10 {
16
17 // schema as used in the compiler for resolving function calls and reporting
18 // errors. These objects should be constructed from C10 schema once those
19 // are available.
20
21 struct Argument;
22 struct FunctionSchema;
23
24 using AliasTypeSet = std::vector<TypePtr>;
25
26 bool operator==(const Argument& lhs, const Argument& rhs);
27
28 struct TORCH_API Argument {
29 Argument(
30 std::string name = "",
31 const TypePtr& type = nullptr,
32 std::optional<int32_t> N = std::nullopt,
33 std::optional<IValue> default_value = std::nullopt,
34 bool kwarg_only = false,
35 std::optional<AliasInfo> alias_info = std::nullopt)
ArgumentArgument36 : Argument(std::move(name), type, type, N, std::move(default_value), kwarg_only, std::move(alias_info)) {}
37
38 Argument(
39 std::string name,
40 TypePtr fake_type,
41 TypePtr real_type,
42 std::optional<int32_t> N = std::nullopt,
43 std::optional<IValue> default_value = std::nullopt,
44 bool kwarg_only = false,
45 std::optional<AliasInfo> alias_info = std::nullopt)
name_Argument46 : name_(std::move(name)),
47 type_(fake_type ? std::move(fake_type) : TensorType::get()),
48 real_type_(real_type ? std::move(real_type) : type_),
49 N_(N),
50 default_value_(std::move(default_value)),
51 alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
52 kwarg_only_(kwarg_only) {
53 // this is an softly-enforced invariant for out arguments.
54 bool is_alias = alias_info_ != nullptr && alias_info_->isWrite();
55 is_out_ = kwarg_only_ && is_alias;
56 }
57
58 Argument(Argument&& rhs) noexcept = default;
59
ArgumentArgument60 Argument(const Argument& rhs)
61 : name_(rhs.name_),
62 type_(rhs.type_),
63 real_type_(rhs.real_type_),
64 N_(rhs.N_),
65 default_value_(rhs.default_value_),
66 alias_info_(rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr),
67 kwarg_only_(rhs.kwarg_only_),
68 is_out_(rhs.is_out_) {}
69
70 Argument& operator=(Argument&& rhs) = default;
71
72 Argument& operator=(const Argument& rhs) {
73 if (this != &rhs) {
74 name_ = rhs.name_;
75 type_ = rhs.type_;
76 real_type_ = rhs.real_type_;
77 N_ = rhs.N_;
78 default_value_ = rhs.default_value_;
79 alias_info_ = rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr;
80 kwarg_only_ = rhs.kwarg_only_;
81 is_out_ = rhs.is_out_;
82 }
83 return *this;
84 }
85
nameArgument86 const std::string& name() const {
87 return name_;
88 }
typeArgument89 const TypePtr& type() const {
90 return type_;
91 }
92 // if type() is non-null, this is guaranteed to be non-null (if no real
93 // type was provided, this takes on type()'s value)
real_typeArgument94 const TypePtr& real_type() const {
95 return real_type_;
96 }
NArgument97 std::optional<int32_t> N() const {
98 return N_;
99 }
default_valueArgument100 const std::optional<IValue>& default_value() const {
101 return default_value_;
102 }
kwarg_onlyArgument103 bool kwarg_only() const {
104 return kwarg_only_;
105 }
106
is_outArgument107 bool is_out() const {
108 return is_out_;
109 }
110
alias_infoArgument111 C10_NODISCARD const AliasInfo* alias_info() const {
112 return alias_info_.get();
113 }
114
is_inferred_typeArgument115 bool is_inferred_type() const {
116 bool is_inferred_type = false;
117 TORCH_INTERNAL_ASSERT(type_);
118 if (auto pt = type_->cast<TensorType>()) {
119 if (pt->isInferredType()) {
120 is_inferred_type = true;
121 }
122 }
123 return is_inferred_type;
124 }
125
formatTypeMismatchMsgArgument126 std::string formatTypeMismatchMsg(const std::string& actual_type) const {
127 std::string inferred_type_hint;
128 if (is_inferred_type()) {
129 inferred_type_hint = c10::str(
130 "Inferred '",
131 name(),
132 "' to be of type 'Tensor' ",
133 "because it was not annotated with an explicit type.\n");
134 }
135 return c10::str(
136 "Expected a value of type '",
137 type()->repr_str(),
138 "' for argument '",
139 name(),
140 "' but instead found type '",
141 actual_type,
142 "'.\n",
143 inferred_type_hint);
144 }
145
cloneWithTypeArgument146 Argument cloneWithType(const TypePtr& new_type) const {
147 return Argument(
148 name_,
149 new_type,
150 N_,
151 default_value_,
152 kwarg_only_,
153 alias_info_ ? std::optional<AliasInfo>(*alias_info_) : std::nullopt);
154 }
155
156 // this function checks whether this Argument is backward compatible with
157 // the old one. we consider the following cases are backward compatible:
158 // 1) two arguments are equal
159 // 2) this arg's type should be subtype of old
160 // 3) this arg must provide the same default value if old arg has one,
161 bool isBackwardCompatibleWith(
162 const Argument& old,
163 std::ostream* why_not=nullptr) const;
164
165 // this function checks whether this Argument is forward compatible with
166 // the old one. we consider the following cases are forward compatible:
167 // 1) two arguments are equal
168 // 2) this arg's type should be subtype of old
169 // 3) this arg must provide the same default value if old arg has one,
170 bool isForwardCompatibleWith(
171 const Argument& old,
172 std::ostream* why_not = nullptr) const;
173
174 private:
175 std::string name_;
176 TypePtr type_;
177 TypePtr real_type_; // this is ScalarType, not int, e.g.
178 // for list types, an optional statically known length for the list
179 // e.g. for int[3]: type = ListType::ofInts(), N = 3
180 // If present, this will allow scalars to be broadcast to this length to
181 // become a list.
182 std::optional<int32_t> N_;
183
184 std::optional<IValue> default_value_;
185 // AliasInfo is huge, so let's only allocate memory for it if
186 // necessary (which it isn't during schema parsing on startup, to
187 // give a pertinent example).
188 std::unique_ptr<AliasInfo> alias_info_;
189 // is this only specifiable as a keyword argument?
190 bool kwarg_only_;
191 // marks if the argument is out variant of the schema
192 bool is_out_;
193 };
194
195 inline bool operator==(const Argument& lhs, const Argument& rhs) {
196 return lhs.name() == rhs.name()
197 && *lhs.type() == *rhs.type()
198 && lhs.N() == rhs.N()
199 && lhs.default_value() == rhs.default_value()
200 && lhs.kwarg_only() == rhs.kwarg_only()
201 && (lhs.alias_info() == rhs.alias_info()
202 || (lhs.alias_info() != nullptr && rhs.alias_info() != nullptr
203 && *lhs.alias_info() == *rhs.alias_info()));
204 }
205
206 inline bool operator!=(const Argument& lhs, const Argument& rhs) {
207 return !(lhs == rhs);
208 }
209
210 enum struct TORCH_API SchemaArgType { input, output };
211
212 /**
213 * struct SchemaArgument
214 *
215 * Structure used to represent arguments or returns for a schema.
216 */
217 struct TORCH_API SchemaArgument {
218 SchemaArgType type;
219 size_t index;
SchemaArgumentSchemaArgument220 SchemaArgument(SchemaArgType tpe, size_t idx) : type(tpe), index(idx) {}
221 bool operator==(const SchemaArgument& rhs) const {
222 return type == rhs.type && index == rhs.index;
223 }
224 };
225
226 bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
227
228 struct TORCH_API FunctionSchema {
229 FunctionSchema(
230 std::string name,
231 std::string overload_name,
232 std::vector<Argument> arguments,
233 std::vector<Argument> returns,
234 bool is_vararg = false,
235 bool is_varret = false)
236 : name_({std::move(name), std::move(overload_name)}),
237 arguments_(std::move(arguments)),
238 returns_(std::move(returns)),
239 is_vararg_(is_vararg),
240 is_varret_(is_varret) {
241 checkSchema();
242 }
243
244 FunctionSchema(
245 Symbol name,
246 std::string overload_name,
247 std::vector<Argument> arguments,
248 std::vector<Argument> returns,
249 bool is_vararg = false,
250 bool is_varret = false)
251 : FunctionSchema(
252 name.toQualString(),
253 std::move(overload_name),
254 std::move(arguments),
255 std::move(returns),
256 is_vararg,
257 is_varret) {
258 checkSchema();
259 }
260
261 // Checks whether this schema is backward compatible with the old one.
262 // The following conditions must be true:
263 // [Function structure] The new schema's name, overload-name, varargs, and
264 // return arity are the same.
265 // [Output Narrowing] The new schema's output type must be the same class
266 // or inherit from the old schema's output type.
267 // [Argument count] The new schema must have at least as many arguments as
268 // the old schema (considering the list of positional and kwargs).
269 // [Arg Compatibility] Every argument in the old schema has a corresponding
270 // argument in the new schema that:
271 // * is at the same position.
272 // * has the same name.
273 // * is either positional, or kwarg and the old argument was kwarg.
274 // * has the same type, or the old argument's type inherits from the
275 // new argument's type.
276 // [Default Values] Every new argument must have a default value.
277 // E.g.
278 // OK f_new(a, b, c=1) => f_old(a, b)
279 // NOK f_new(a, c=1, *, b) => f_old(a, *, b)
280 // OK f_new(a, b, *, c) => f_old(a, *, b, c)
281 // NOK f_new(a, *, b, c) -> f_old(a, b, *, c)
282 // NOK f_new(a, *, c, b) => f_old(a, *, b, c)
283 // OK f_new(a, *, b, c, d=1) => f_old(a, *, b, c)
284 bool isBackwardCompatibleWith(
285 const FunctionSchema& old,
286 std::ostream* why_not = nullptr) const;
287
288 // Checks whether this schema is forward compatible with the old one.
289 // The following conditions must be true:
290 // [Function structure] The new schema's name, overload-name, varargs, and
291 // return arity are the same.
292 // [Output Narrowing] The new schema's output type must be the same class
293 // or inherit from the old schema's output type.
294 // [Arg Compatibility] Every argument in the old schema has a corresponding
295 // argument in the new schema that:
296 // * is at the same position.
297 // * has the same name.
298 // * is either positional, or kwarg and the old argument was kwarg.
299 // * has the same type, or the old argument's type inherits from the
300 // new argument's type.
301 // [Default Values] Every new argument must have a default value.
302 // Each default value type should NOT be a container type.
303 // [Positioning] All defaults arguments MUST go after either old
304 // default arguments or the end of positional arguments
305 // and right BEFORE all out arguments
306 bool isForwardCompatibleWith(
307 const FunctionSchema& old,
308 std::ostringstream& why_not) const;
309
310 private:
311 OperatorName name_;
312 std::vector<Argument> arguments_;
313 std::vector<Argument> returns_;
314 // if true then this schema takes an arbitrary number of additional arguments
315 // after the argument specified in arguments
316 // currently this is used primarily to represent 'primitive' operators whose
317 // arguments are not checked by schema
318 bool is_vararg_;
319 bool is_varret_;
320
321 // if no alias information is directly specified, what kind of "default"
322 // alias information should we infer?
323 // NB: due to alias analysis kind merging, this may be nullopt. Eventually
324 // this should always be set no matter what
325 std::optional<AliasAnalysisKind> alias_kind_;
326
327 template <typename T>
328 void checkArg(const IValue& value, const Argument& argument, std::optional<size_t> pos) const;
329
checkSchemaFunctionSchema330 void checkSchema() const {
331 bool seen_default_arg = false;
332 for (const auto& arg : arguments()) {
333 if (arg.default_value()) {
334 seen_default_arg = true;
335 } else {
336 // we have historically serialized broadcasting lists wo/default values,
337 // so to not break BC allow lists here
338 if (arg.type()->kind() == ListType::Kind) {
339 continue;
340 }
341 TORCH_INTERNAL_ASSERT(
342 !seen_default_arg || arg.kwarg_only(),
343 "Non-default positional argument follows default argument. Parameter ",
344 arg.name(),
345 " in ",
346 *this);
347 }
348 }
349 }
350
351 public:
352
353 void dump() const;
354
operator_nameFunctionSchema355 const OperatorName& operator_name() const {
356 return name_;
357 }
nameFunctionSchema358 const std::string& name() const {
359 return name_.name;
360 }
overload_nameFunctionSchema361 const std::string& overload_name() const {
362 return name_.overload_name;
363 }
argumentsFunctionSchema364 const std::vector<Argument>& arguments() const {
365 return arguments_;
366 }
returnsFunctionSchema367 const std::vector<Argument>& returns() const {
368 return returns_;
369 }
is_varargFunctionSchema370 bool is_vararg() const {
371 return is_vararg_;
372 }
is_varretFunctionSchema373 bool is_varret() const {
374 return is_varret_;
375 }
is_aliasingFunctionSchema376 bool is_aliasing(const c10::SchemaArgument &argument) const {
377 TORCH_INTERNAL_ASSERT(
378 argument.index < getCorrectList(argument.type).size(),
379 "Invalid index for schema.");
380 const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info();
381 return aliasInfo;
382 }
is_mutableFunctionSchema383 bool is_mutable() const {
384 return std::any_of(
385 arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
386 const AliasInfo* aliasInfo = arg.alias_info();
387 return aliasInfo && aliasInfo->isWrite();
388 });
389 }
is_mutableFunctionSchema390 bool is_mutable(const c10::SchemaArgument &argument) const {
391 TORCH_INTERNAL_ASSERT(
392 argument.index < getCorrectList(argument.type).size(),
393 "Invalid index for schema.");
394 const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info();
395 return aliasInfo && aliasInfo->isWrite();
396 }
is_mutableFunctionSchema397 bool is_mutable(c10::string_view name) const {
398 std::optional<int> index = argumentIndexWithName(name);
399 TORCH_INTERNAL_ASSERT(
400 index != std::nullopt, "Schema has no argument named ", name);
401
402 return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
403 }
404
405 // Returns whether lhs and rhs may alias directly.
406 // This does not account for cases where lhs or rhs are a container that
407 // may contain elements that alias the other argument.
408 // FunctionSchema::may_contain_alias will include that functionality.
409 bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const;
410
411 // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container
412 // that may contain elements that alias the other argument.
413 // bidirectional = false only returns whether lhs may contain an alias of rhs
414 // while bidirectional = true returns both directions.
415 bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const;
416
417 // Returns whether the two AliasTypeSets contain any similarities
418 // ie: whether the two type sets can alias.
419 bool canAliasTypeSetsAlias(const std::optional<AliasTypeSet> &lhs, const std::optional<AliasTypeSet> &rhs) const;
420
421 // Recursively Finds all contained types within the AliasTypeSet.
422 std::optional<AliasTypeSet> getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> &aliasTypeSet) const;
423
424 // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp.
425 // Used to map types to a type such that all types that can alias will be mapped to the same type.
426 // For example, calling this method on 'Optional[List[int]]' is the same as calling this method
427 // on 'List[int]'.
428 std::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) const;
429
430 // Returns either arguments() or returns() depending on the SchemaArgType
431 // output => returns(), input => arguments()
432 const std::vector<Argument>& getCorrectList(SchemaArgType type) const;
433
argumentIndexWithNameFunctionSchema434 std::optional<int> argumentIndexWithName(c10::string_view name) const {
435 for (const auto i : c10::irange(arguments().size())) {
436 if(name == arguments()[i].name())
437 return i;
438 }
439 return std::nullopt;
440 }
cloneWithNameFunctionSchema441 FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
442 return FunctionSchema(
443 std::move(name),
444 std::move(overload_name),
445 arguments(),
446 returns(),
447 is_vararg(),
448 is_varret()
449 );
450 }
cloneWithArgumentsFunctionSchema451 FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
452 return FunctionSchema(
453 name(),
454 overload_name(),
455 std::move(new_arguments),
456 returns(),
457 is_vararg(),
458 is_varret());
459 }
cloneWithReturnsFunctionSchema460 FunctionSchema cloneWithReturns(std::vector<Argument> new_returns) const {
461 return FunctionSchema(
462 name(),
463 overload_name(),
464 arguments(),
465 std::move(new_returns),
466 is_vararg(),
467 is_varret());
468 }
469
470 std::string formatTypeMismatchMsg(
471 const Argument& expected,
472 const std::string& actual_type,
473 std::optional<size_t> position = std::nullopt,
474 std::optional<std::string> value = std::nullopt) const;
475
476 FunctionSchema cloneWithRemappedTypes(
477 const std::function<TypePtr(TypePtr)> type_map) const;
478
479 FunctionSchema cloneWithRealTypes(bool with_symint=true) const;
480
481 // Check that inputs have the correct types and appends any missing default
482 // values.
483 template <typename T = c10::PlatformType>
484 void checkAndNormalizeInputs(
485 std::vector<IValue>& inputs,
486 const std::unordered_map<std::string, IValue>& kwargs =
487 std::unordered_map<std::string, IValue>{}) const;
488
489 std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;
490
hasAnyAliasInfoFunctionSchema491 bool hasAnyAliasInfo() const {
492 for (const auto& arg : arguments_) {
493 if (arg.alias_info() != nullptr) {
494 return true;
495 }
496 }
497 for (const auto& ret : returns_) {
498 if (ret.alias_info() != nullptr) {
499 return true;
500 }
501 }
502 return false;
503 }
504
505
506 // TODO remove the mutation here
isDefaultAliasAnalysisKindFunctionSchema507 bool isDefaultAliasAnalysisKind() const {
508 return !alias_kind_;
509 }
aliasAnalysisFunctionSchema510 AliasAnalysisKind aliasAnalysis() const {
511 return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE);
512 }
setAliasAnalysisFunctionSchema513 void setAliasAnalysis(AliasAnalysisKind v) {
514 alias_kind_ = v;
515 }
516
getNamespaceFunctionSchema517 std::optional<c10::string_view> getNamespace() const {
518 return name_.getNamespace();
519 }
520
521 // Returns true if we successfully set the namespace (as there
522 // was none set, and false otherwise)
setNamespaceIfNotSetFunctionSchema523 bool setNamespaceIfNotSet(const char* ns) {
524 return name_.setNamespaceIfNotSet(ns);
525 }
526
527 // can a function with this schema be substituted for a function of rhs's
528 // schema and have the program typecheck?
529 // as_method - if true, treat this schema as a method and ignore
530 // the first argument, which will be the object in both cases
531 bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
532 };
533
534 inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
535 return lhs.name() == rhs.name()
536 && lhs.overload_name() == rhs.overload_name()
537 && lhs.arguments() == rhs.arguments()
538 && lhs.returns() == rhs.returns()
539 && lhs.is_vararg() == rhs.is_vararg()
540 && lhs.is_varret() == rhs.is_varret();
541 }
542
543 inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
544 return !(lhs == rhs);
545 }
546
547 // print out Argument, which is compatible with FunctionSchema parser
548 // full format: Type(alias)? name=default_value
549 inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
550
551 // for adjusting the ? position.
552 // in schema, we have Tensor?(a!) input, and t(a!)?.
553 // however, t?(a!) doesn't work with schema parser.
554 // so we always use Type(alias)? format
555 // real_type versus fake_type: in order to be compatible with FunctionSchema
556 // parser, printing an argument with either MemoryFormat or Layout type should
557 // give us the original schema string, hence printing out real_type.
558 auto type = arg.real_type();
559 bool is_opt = type->kind() == OptionalType::Kind;
560 auto unopt_type = is_opt ? type->castRaw<OptionalType>()->getElementType() : type;
561
562 if (unopt_type->kind() == ListType::Kind) {
563 // sized lists get size N from arg, not type
564 auto list = unopt_type->cast<c10::ListType>();
565 out << list->getElementType()->str();
566 if (arg.alias_info() && !arg.alias_info()->containedTypes().empty()){
567 out << arg.alias_info()->containedTypes()[0];
568 }
569 std::string N = "";
570 if (arg.N()) {
571 N = std::to_string(*arg.N());
572 }
573 out << "[" << N << "]";
574 } else {
575 out << unopt_type->str();
576 }
577
578 // print alias info if it has beforeSets.
579 if (arg.alias_info() && !arg.alias_info()->beforeSets().empty()) {
580 out << *arg.alias_info();
581 }
582
583 if (is_opt) {
584 out << "?";
585 }
586
587 if (!arg.name().empty()) {
588 out << " " << arg.name();
589 }
590
591 if (arg.default_value()) {
592 out << "=";
593 if ((type->kind() == c10::TypeKind::StringType ||
594 unopt_type->kind() == c10::TypeKind::StringType) &&
595 arg.default_value().value().isString()) {
596 printQuotedString(out, arg.default_value().value().toStringRef());
597 } else if (type->kind() == TypeKind::ListType && type->castRaw<ListType>()->getElementType()->kind() == c10::TypeKind::IntType) {
598 // We want to faithfully replicate JIT schema.
599 // in native_functions.yaml defaults for int arrays with a single value always look like
600 // int[2] stride=1
601 // instead of
602 // int[2] stride=[1, 1]
603 auto default_val = arg.default_value().value().toIntList();
604 if (default_val.size() > 1) {
605 auto all_defaults_the_same = true;
606 for (const auto i : c10::irange(1, default_val.size())) {
607 if (default_val[0] != default_val[i]) all_defaults_the_same = false;
608 }
609 if (all_defaults_the_same) {
610 out << default_val[0];
611 } else {
612 out << arg.default_value().value();
613 }
614 } else {
615 out << arg.default_value().value();
616 }
617 } else {
618 out << arg.default_value().value();
619 }
620 }
621
622 return out;
623 }
624
625 TORCH_API std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
626
toString(const FunctionSchema & schema)627 inline std::string toString(const FunctionSchema& schema) {
628 std::ostringstream str;
629 str << schema;
630 return str.str();
631 }
632
633 } // namespace c10
634
635 namespace std {
636 template<>
637 struct hash<c10::SchemaArgument> {
638 size_t operator()(const c10::SchemaArgument& arg) const
639 {
640 return c10::hash_combine(std::hash<size_t>()(arg.index), std::hash<size_t>()(static_cast<std::size_t>(arg.type)));
641 }
642 };
643 template<>
644 struct hash<c10::Argument> {
645 size_t operator()(const c10::Argument& arg) const
646 {
647 auto hash = std::hash<std::string>{}(arg.name());
648 auto type_hash = std::hash<c10::TypePtr>{}(arg.type());
649 auto kwarg_only_hash = std::hash<bool>{}(arg.kwarg_only());
650 hash = c10::hash_combine(hash, type_hash);
651 hash = c10::hash_combine(hash, kwarg_only_hash);
652 // hashing optional fields if they exist
653 if (arg.default_value()) {
654 auto default_value_hash = c10::hash<c10::IValue>{}(arg.default_value().value());
655 hash = c10::hash_combine(hash, default_value_hash);
656 }
657 if (arg.N()) {
658 auto N_hash = std::hash<int64_t>{}(*arg.N());
659 hash = c10::hash_combine(hash, N_hash);
660 }
661 if (arg.alias_info()) {
662 auto alias_info_hash = std::hash<c10::AliasInfo>{}(*arg.alias_info());
663 hash = c10::hash_combine(hash, alias_info_hash);
664 }
665 return hash;
666 }
667 };
668 template<>
669 struct hash<c10::FunctionSchema> {
670 size_t operator()(const c10::FunctionSchema& schema) const
671 {
672 auto hash = std::hash<c10::OperatorName>{}(schema.operator_name());
673 auto args_hash = c10::hash<std::vector<c10::Argument>>{}(schema.arguments());
674 auto returns_hash = c10::hash<std::vector<c10::Argument>>{}(schema.returns());
675 auto is_vararg_hash = std::hash<bool>{}(schema.is_vararg());
676 auto is_varret_hash = std::hash<bool>{}(schema.is_varret());
677 hash = c10::hash_combine(hash, args_hash);
678 hash = c10::hash_combine(hash, returns_hash);
679 hash = c10::hash_combine(hash, is_vararg_hash);
680 hash = c10::hash_combine(hash, is_varret_hash);
681 return hash;
682 }
683 };
684 } // namespace std
685
686
687 #include <ATen/core/function_schema_inl.h> // IWYU pragma: keep
688