xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/function_schema.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/StringUtil.h>
4 #include <c10/util/string_view.h>
5 #include <c10/util/irange.h>
6 #include <ATen/core/jit_type.h>
7 #include <ATen/core/symbol.h>
8 #include <ATen/core/ivalue.h>
9 #include <ATen/core/alias_info.h>
10 #include <ATen/core/operator_name.h>
11 #include <ATen/core/dispatch/OperatorOptions.h>
12 #include <unordered_map>
13 #include <utility>
14 
15 namespace c10 {
16 
17 // schema as used in the compiler for resolving function calls and reporting
18 // errors. These objects should be constructed from C10 schema once those
19 // are available.
20 
21 struct Argument;
22 struct FunctionSchema;
23 
24 using AliasTypeSet = std::vector<TypePtr>;
25 
26 bool operator==(const Argument& lhs, const Argument& rhs);
27 
28 struct TORCH_API Argument {
29   Argument(
30       std::string name = "",
31       const TypePtr& type = nullptr,
32       std::optional<int32_t> N = std::nullopt,
33       std::optional<IValue> default_value = std::nullopt,
34       bool kwarg_only = false,
35       std::optional<AliasInfo> alias_info = std::nullopt)
ArgumentArgument36     : Argument(std::move(name), type, type, N, std::move(default_value), kwarg_only, std::move(alias_info)) {}
37 
38   Argument(
39       std::string name,
40       TypePtr fake_type,
41       TypePtr real_type,
42       std::optional<int32_t> N = std::nullopt,
43       std::optional<IValue> default_value = std::nullopt,
44       bool kwarg_only = false,
45       std::optional<AliasInfo> alias_info = std::nullopt)
name_Argument46       : name_(std::move(name)),
47         type_(fake_type ? std::move(fake_type) : TensorType::get()),
48         real_type_(real_type ? std::move(real_type) : type_),
49         N_(N),
50         default_value_(std::move(default_value)),
51         alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
52         kwarg_only_(kwarg_only) {
53     // this is an softly-enforced invariant for out arguments.
54     bool is_alias = alias_info_ != nullptr && alias_info_->isWrite();
55     is_out_ = kwarg_only_ && is_alias;
56   }
57 
58   Argument(Argument&& rhs) noexcept = default;
59 
ArgumentArgument60   Argument(const Argument& rhs)
61       : name_(rhs.name_),
62         type_(rhs.type_),
63         real_type_(rhs.real_type_),
64         N_(rhs.N_),
65         default_value_(rhs.default_value_),
66         alias_info_(rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr),
67         kwarg_only_(rhs.kwarg_only_),
68         is_out_(rhs.is_out_) {}
69 
70   Argument& operator=(Argument&& rhs) = default;
71 
72   Argument& operator=(const Argument& rhs) {
73     if (this != &rhs) {
74       name_ = rhs.name_;
75       type_ = rhs.type_;
76       real_type_ = rhs.real_type_;
77       N_ = rhs.N_;
78       default_value_ = rhs.default_value_;
79       alias_info_ = rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr;
80       kwarg_only_ = rhs.kwarg_only_;
81       is_out_ = rhs.is_out_;
82     }
83     return *this;
84   }
85 
nameArgument86   const std::string& name() const {
87     return name_;
88   }
typeArgument89   const TypePtr& type() const {
90     return type_;
91   }
92   // if type() is non-null, this is guaranteed to be non-null (if no real
93   // type was provided, this takes on type()'s value)
real_typeArgument94   const TypePtr& real_type() const {
95     return real_type_;
96   }
NArgument97   std::optional<int32_t> N() const {
98     return N_;
99   }
default_valueArgument100   const std::optional<IValue>& default_value() const {
101     return default_value_;
102   }
kwarg_onlyArgument103   bool kwarg_only() const {
104     return kwarg_only_;
105   }
106 
is_outArgument107   bool is_out() const {
108     return is_out_;
109   }
110 
alias_infoArgument111   C10_NODISCARD const AliasInfo* alias_info() const {
112     return alias_info_.get();
113   }
114 
is_inferred_typeArgument115   bool is_inferred_type() const {
116     bool is_inferred_type = false;
117     TORCH_INTERNAL_ASSERT(type_);
118     if (auto pt = type_->cast<TensorType>()) {
119       if (pt->isInferredType()) {
120         is_inferred_type = true;
121       }
122     }
123     return is_inferred_type;
124   }
125 
formatTypeMismatchMsgArgument126   std::string formatTypeMismatchMsg(const std::string& actual_type) const {
127     std::string inferred_type_hint;
128     if (is_inferred_type()) {
129       inferred_type_hint = c10::str(
130           "Inferred '",
131           name(),
132           "' to be of type 'Tensor' ",
133           "because it was not annotated with an explicit type.\n");
134     }
135     return c10::str(
136         "Expected a value of type '",
137         type()->repr_str(),
138         "' for argument '",
139         name(),
140         "' but instead found type '",
141         actual_type,
142         "'.\n",
143         inferred_type_hint);
144   }
145 
cloneWithTypeArgument146   Argument cloneWithType(const TypePtr& new_type) const {
147     return Argument(
148         name_,
149         new_type,
150         N_,
151         default_value_,
152         kwarg_only_,
153         alias_info_ ? std::optional<AliasInfo>(*alias_info_) : std::nullopt);
154   }
155 
156   // this function checks whether this Argument is backward compatible with
157   // the old one. we consider the following cases are backward compatible:
158   //   1) two arguments are equal
159   //   2) this arg's type should be subtype of old
160   //   3) this arg must provide the same default value if old arg has one,
161   bool isBackwardCompatibleWith(
162       const Argument& old,
163       std::ostream* why_not=nullptr) const;
164 
165   // this function checks whether this Argument is forward compatible with
166   // the old one. we consider the following cases are forward compatible:
167   //   1) two arguments are equal
168   //   2) this arg's type should be subtype of old
169   //   3) this arg must provide the same default value if old arg has one,
170   bool isForwardCompatibleWith(
171       const Argument& old,
172       std::ostream* why_not = nullptr) const;
173 
174  private:
175   std::string name_;
176   TypePtr type_;
177   TypePtr real_type_; // this is ScalarType, not int, e.g.
178   // for list types, an optional statically known length for the list
179   // e.g. for int[3]: type = ListType::ofInts(), N = 3
180   // If present, this will allow scalars to be broadcast to this length to
181   // become a list.
182   std::optional<int32_t> N_;
183 
184   std::optional<IValue> default_value_;
185   // AliasInfo is huge, so let's only allocate memory for it if
186   // necessary (which it isn't during schema parsing on startup, to
187   // give a pertinent example).
188   std::unique_ptr<AliasInfo> alias_info_;
189   // is this only specifiable as a keyword argument?
190   bool kwarg_only_;
191   // marks if the argument is out variant of the schema
192   bool is_out_;
193 };
194 
195 inline bool operator==(const Argument& lhs, const Argument& rhs) {
196   return lhs.name() == rhs.name()
197           && *lhs.type() == *rhs.type()
198           && lhs.N() == rhs.N()
199           && lhs.default_value() == rhs.default_value()
200           && lhs.kwarg_only() == rhs.kwarg_only()
201           && (lhs.alias_info() == rhs.alias_info()
202               || (lhs.alias_info() != nullptr && rhs.alias_info() != nullptr
203                    && *lhs.alias_info() == *rhs.alias_info()));
204 }
205 
206 inline bool operator!=(const Argument& lhs, const Argument& rhs) {
207   return !(lhs == rhs);
208 }
209 
210 enum struct TORCH_API SchemaArgType { input, output };
211 
212 /**
213  * struct SchemaArgument
214  *
215  * Structure used to represent arguments or returns for a schema.
216  */
217 struct TORCH_API SchemaArgument {
218   SchemaArgType type;
219   size_t index;
SchemaArgumentSchemaArgument220   SchemaArgument(SchemaArgType tpe, size_t idx) : type(tpe), index(idx) {}
221   bool operator==(const SchemaArgument& rhs) const {
222     return type == rhs.type && index == rhs.index;
223   }
224 };
225 
226 bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
227 
228 struct TORCH_API FunctionSchema {
229   FunctionSchema(
230       std::string name,
231       std::string overload_name,
232       std::vector<Argument> arguments,
233       std::vector<Argument> returns,
234       bool is_vararg = false,
235       bool is_varret = false)
236       : name_({std::move(name), std::move(overload_name)}),
237         arguments_(std::move(arguments)),
238         returns_(std::move(returns)),
239         is_vararg_(is_vararg),
240         is_varret_(is_varret) {
241     checkSchema();
242   }
243 
244   FunctionSchema(
245       Symbol name,
246       std::string overload_name,
247       std::vector<Argument> arguments,
248       std::vector<Argument> returns,
249       bool is_vararg = false,
250       bool is_varret = false)
251       : FunctionSchema(
252             name.toQualString(),
253             std::move(overload_name),
254             std::move(arguments),
255             std::move(returns),
256             is_vararg,
257             is_varret) {
258     checkSchema();
259   }
260 
261   // Checks whether this schema is backward compatible with the old one.
262   // The following conditions must be true:
263   // [Function structure] The new schema's name, overload-name, varargs, and
264   //      return arity are the same.
265   // [Output Narrowing] The new schema's output type must be the same class
266   //      or inherit from the old schema's output type.
267   // [Argument count] The new schema must have at least as many arguments as
268   //      the old schema (considering the list of positional and kwargs).
269   // [Arg Compatibility] Every argument in the old schema has a corresponding
270   //      argument in the new schema that:
271   //        * is at the same position.
272   //        * has the same name.
273   //        * is either positional, or kwarg and the old argument was kwarg.
274   //        * has the same type, or the old argument's type inherits from the
275   //          new argument's type.
276   // [Default Values] Every new argument must have a default value.
277   // E.g.
278   //   OK    f_new(a, b, c=1) => f_old(a, b)
279   //   NOK   f_new(a, c=1, *, b) => f_old(a, *, b)
280   //   OK    f_new(a, b, *, c) => f_old(a, *, b, c)
281   //   NOK   f_new(a, *, b, c) -> f_old(a, b, *, c)
282   //   NOK   f_new(a, *, c, b) => f_old(a, *, b, c)
283   //   OK    f_new(a, *, b, c, d=1) => f_old(a, *, b, c)
284   bool isBackwardCompatibleWith(
285       const FunctionSchema& old,
286       std::ostream* why_not = nullptr) const;
287 
288   // Checks whether this schema is forward compatible with the old one.
289   // The following conditions must be true:
290   // [Function structure] The new schema's name, overload-name, varargs, and
291   //      return arity are the same.
292   // [Output Narrowing] The new schema's output type must be the same class
293   //      or inherit from the old schema's output type.
294   // [Arg Compatibility] Every argument in the old schema has a corresponding
295   //      argument in the new schema that:
296   //        * is at the same position.
297   //        * has the same name.
298   //        * is either positional, or kwarg and the old argument was kwarg.
299   //        * has the same type, or the old argument's type inherits from the
300   //          new argument's type.
301   // [Default Values] Every new argument must have a default value.
302   //         Each default value type should NOT be a container type.
303   // [Positioning] All defaults arguments MUST go after either old
304   //         default arguments or the end of positional arguments
305   //         and right BEFORE all out arguments
306   bool isForwardCompatibleWith(
307       const FunctionSchema& old,
308       std::ostringstream& why_not) const;
309 
310  private:
311   OperatorName name_;
312   std::vector<Argument> arguments_;
313   std::vector<Argument> returns_;
314   // if true then this schema takes an arbitrary number of additional arguments
315   // after the argument specified in arguments
316   // currently this is used primarily to represent 'primitive' operators whose
317   // arguments are not checked by schema
318   bool is_vararg_;
319   bool is_varret_;
320 
321   // if no alias information is directly specified, what kind of "default"
322   // alias information should we infer?
323   // NB: due to alias analysis kind merging, this may be nullopt.  Eventually
324   // this should always be set no matter what
325   std::optional<AliasAnalysisKind> alias_kind_;
326 
327   template <typename T>
328   void checkArg(const IValue& value, const Argument& argument, std::optional<size_t> pos) const;
329 
checkSchemaFunctionSchema330   void checkSchema() const {
331     bool seen_default_arg = false;
332     for (const auto& arg : arguments()) {
333       if (arg.default_value()) {
334         seen_default_arg = true;
335       } else {
336         // we have historically serialized broadcasting lists wo/default values,
337         // so to not break BC allow lists here
338         if (arg.type()->kind() == ListType::Kind) {
339           continue;
340         }
341         TORCH_INTERNAL_ASSERT(
342             !seen_default_arg || arg.kwarg_only(),
343             "Non-default positional argument follows default argument. Parameter ",
344             arg.name(),
345             " in ",
346             *this);
347       }
348     }
349   }
350 
351  public:
352 
353   void dump() const;
354 
operator_nameFunctionSchema355   const OperatorName& operator_name() const {
356     return name_;
357   }
nameFunctionSchema358   const std::string& name() const {
359     return name_.name;
360   }
overload_nameFunctionSchema361   const std::string& overload_name() const {
362     return name_.overload_name;
363   }
argumentsFunctionSchema364   const std::vector<Argument>& arguments() const {
365     return arguments_;
366   }
returnsFunctionSchema367   const std::vector<Argument>& returns() const {
368     return returns_;
369   }
is_varargFunctionSchema370   bool is_vararg() const {
371     return is_vararg_;
372   }
is_varretFunctionSchema373   bool is_varret() const {
374     return is_varret_;
375   }
is_aliasingFunctionSchema376   bool is_aliasing(const c10::SchemaArgument &argument) const {
377     TORCH_INTERNAL_ASSERT(
378     argument.index < getCorrectList(argument.type).size(),
379     "Invalid index for schema.");
380     const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info();
381     return aliasInfo;
382   }
is_mutableFunctionSchema383   bool is_mutable() const {
384     return std::any_of(
385         arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
386           const AliasInfo* aliasInfo = arg.alias_info();
387           return aliasInfo && aliasInfo->isWrite();
388         });
389   }
is_mutableFunctionSchema390   bool is_mutable(const c10::SchemaArgument &argument) const {
391     TORCH_INTERNAL_ASSERT(
392         argument.index < getCorrectList(argument.type).size(),
393         "Invalid index for schema.");
394     const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info();
395     return aliasInfo && aliasInfo->isWrite();
396   }
is_mutableFunctionSchema397   bool is_mutable(c10::string_view name) const {
398     std::optional<int> index = argumentIndexWithName(name);
399     TORCH_INTERNAL_ASSERT(
400         index != std::nullopt, "Schema has no argument named ", name);
401 
402     return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
403   }
404 
405   // Returns whether lhs and rhs may alias directly.
406   // This does not account for cases where lhs or rhs are a container that
407   // may contain elements that alias the other argument.
408   // FunctionSchema::may_contain_alias will include that functionality.
409   bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const;
410 
411   // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container
412   // that may contain elements that alias the other argument.
413   // bidirectional = false only returns whether lhs may contain an alias of rhs
414   // while bidirectional = true returns both directions.
415   bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const;
416 
417   // Returns whether the two AliasTypeSets contain any similarities
418   // ie: whether the two type sets can alias.
419   bool canAliasTypeSetsAlias(const std::optional<AliasTypeSet> &lhs, const std::optional<AliasTypeSet> &rhs) const;
420 
421   // Recursively Finds all contained types within the AliasTypeSet.
422   std::optional<AliasTypeSet> getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> &aliasTypeSet) const;
423 
424   // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp.
425   // Used to map types to a type such that all types that can alias will be mapped to the same type.
426   // For example, calling this method on 'Optional[List[int]]' is the same as calling this method
427   // on 'List[int]'.
428   std::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) const;
429 
430   // Returns either arguments() or returns() depending on the SchemaArgType
431   // output => returns(), input => arguments()
432   const std::vector<Argument>& getCorrectList(SchemaArgType type) const;
433 
argumentIndexWithNameFunctionSchema434   std::optional<int> argumentIndexWithName(c10::string_view name) const {
435     for (const auto i : c10::irange(arguments().size())) {
436       if(name == arguments()[i].name())
437         return i;
438     }
439     return std::nullopt;
440   }
cloneWithNameFunctionSchema441   FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
442     return FunctionSchema(
443         std::move(name),
444         std::move(overload_name),
445         arguments(),
446         returns(),
447         is_vararg(),
448         is_varret()
449         );
450   }
cloneWithArgumentsFunctionSchema451   FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
452     return FunctionSchema(
453         name(),
454         overload_name(),
455         std::move(new_arguments),
456         returns(),
457         is_vararg(),
458         is_varret());
459   }
cloneWithReturnsFunctionSchema460   FunctionSchema cloneWithReturns(std::vector<Argument> new_returns) const {
461     return FunctionSchema(
462         name(),
463         overload_name(),
464         arguments(),
465         std::move(new_returns),
466         is_vararg(),
467         is_varret());
468   }
469 
470   std::string formatTypeMismatchMsg(
471       const Argument& expected,
472       const std::string& actual_type,
473       std::optional<size_t> position = std::nullopt,
474       std::optional<std::string> value = std::nullopt) const;
475 
476   FunctionSchema cloneWithRemappedTypes(
477       const std::function<TypePtr(TypePtr)> type_map) const;
478 
479   FunctionSchema cloneWithRealTypes(bool with_symint=true) const;
480 
481   // Check that inputs have the correct types and appends any missing default
482   // values.
483   template <typename T = c10::PlatformType>
484   void checkAndNormalizeInputs(
485       std::vector<IValue>& inputs,
486       const std::unordered_map<std::string, IValue>& kwargs =
487           std::unordered_map<std::string, IValue>{}) const;
488 
489   std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;
490 
hasAnyAliasInfoFunctionSchema491   bool hasAnyAliasInfo() const {
492     for (const auto& arg : arguments_) {
493       if (arg.alias_info() != nullptr) {
494         return true;
495       }
496     }
497     for (const auto& ret : returns_) {
498       if (ret.alias_info() != nullptr) {
499         return true;
500       }
501     }
502     return false;
503   }
504 
505 
506   // TODO remove the mutation here
isDefaultAliasAnalysisKindFunctionSchema507   bool isDefaultAliasAnalysisKind() const {
508     return !alias_kind_;
509   }
aliasAnalysisFunctionSchema510   AliasAnalysisKind aliasAnalysis() const {
511     return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE);
512   }
setAliasAnalysisFunctionSchema513   void setAliasAnalysis(AliasAnalysisKind v) {
514     alias_kind_ = v;
515   }
516 
getNamespaceFunctionSchema517   std::optional<c10::string_view> getNamespace() const {
518     return name_.getNamespace();
519   }
520 
521   // Returns true if we successfully set the namespace (as there
522   // was none set, and false otherwise)
setNamespaceIfNotSetFunctionSchema523   bool setNamespaceIfNotSet(const char* ns) {
524     return name_.setNamespaceIfNotSet(ns);
525   }
526 
527   // can a function with this schema be substituted for a function of rhs's
528   // schema and have the program typecheck?
529   // as_method - if true, treat this schema as a method and ignore
530   // the first argument, which will be the object in both cases
531   bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
532 };
533 
534 inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
535   return lhs.name() == rhs.name()
536      && lhs.overload_name() == rhs.overload_name()
537      && lhs.arguments() == rhs.arguments()
538      && lhs.returns() == rhs.returns()
539      && lhs.is_vararg() == rhs.is_vararg()
540      && lhs.is_varret() == rhs.is_varret();
541 }
542 
543 inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
544   return !(lhs == rhs);
545 }
546 
547 // print out Argument, which is compatible with FunctionSchema parser
548 // full format: Type(alias)? name=default_value
549 inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
550 
551   // for adjusting the ? position.
552   // in schema, we have Tensor?(a!) input, and t(a!)?.
553   // however, t?(a!) doesn't work with schema parser.
554   // so we always use Type(alias)? format
555   // real_type versus fake_type: in order to be compatible with FunctionSchema
556   // parser, printing an argument with either MemoryFormat or Layout type should
557   // give us the original schema string, hence printing out real_type.
558   auto type = arg.real_type();
559   bool is_opt = type->kind() == OptionalType::Kind;
560   auto unopt_type = is_opt ? type->castRaw<OptionalType>()->getElementType() : type;
561 
562   if (unopt_type->kind() == ListType::Kind) {
563     // sized lists get size N from arg, not type
564     auto list = unopt_type->cast<c10::ListType>();
565     out << list->getElementType()->str();
566     if (arg.alias_info() && !arg.alias_info()->containedTypes().empty()){
567       out << arg.alias_info()->containedTypes()[0];
568     }
569     std::string N = "";
570     if (arg.N()) {
571         N = std::to_string(*arg.N());
572     }
573     out << "[" << N << "]";
574   } else {
575     out << unopt_type->str();
576   }
577 
578   // print alias info if it has beforeSets.
579   if (arg.alias_info() && !arg.alias_info()->beforeSets().empty()) {
580     out << *arg.alias_info();
581   }
582 
583   if (is_opt) {
584     out << "?";
585   }
586 
587   if (!arg.name().empty()) {
588     out << " " << arg.name();
589   }
590 
591   if (arg.default_value()) {
592     out << "=";
593     if ((type->kind() == c10::TypeKind::StringType ||
594         unopt_type->kind() == c10::TypeKind::StringType) &&
595         arg.default_value().value().isString()) {
596       printQuotedString(out, arg.default_value().value().toStringRef());
597     } else if (type->kind() == TypeKind::ListType && type->castRaw<ListType>()->getElementType()->kind() == c10::TypeKind::IntType) {
598       // We want to faithfully replicate JIT schema.
599       // in native_functions.yaml defaults for int arrays with a single value always look like
600       //   int[2] stride=1
601       // instead of
602       //   int[2] stride=[1, 1]
603       auto default_val = arg.default_value().value().toIntList();
604       if (default_val.size() > 1) {
605         auto all_defaults_the_same = true;
606         for (const auto i : c10::irange(1, default_val.size())) {
607           if (default_val[0] != default_val[i]) all_defaults_the_same = false;
608         }
609         if (all_defaults_the_same) {
610           out << default_val[0];
611         } else {
612           out << arg.default_value().value();
613         }
614       } else {
615         out << arg.default_value().value();
616       }
617     } else {
618       out << arg.default_value().value();
619     }
620   }
621 
622   return out;
623 }
624 
625 TORCH_API std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
626 
// Renders `schema` through its stream operator and returns the text.
inline std::string toString(const FunctionSchema& schema) {
  std::ostringstream rendered;
  rendered << schema;
  return rendered.str();
}
632 
633 } // namespace c10
634 
635 namespace std {
636 template<>
637   struct hash<c10::SchemaArgument> {
638     size_t operator()(const c10::SchemaArgument& arg) const
639     {
640       return c10::hash_combine(std::hash<size_t>()(arg.index), std::hash<size_t>()(static_cast<std::size_t>(arg.type)));
641     }
642   };
643 template<>
644   struct hash<c10::Argument> {
645     size_t operator()(const c10::Argument& arg) const
646     {
647       auto hash = std::hash<std::string>{}(arg.name());
648       auto type_hash = std::hash<c10::TypePtr>{}(arg.type());
649       auto kwarg_only_hash = std::hash<bool>{}(arg.kwarg_only());
650       hash = c10::hash_combine(hash, type_hash);
651       hash = c10::hash_combine(hash, kwarg_only_hash);
652       // hashing optional fields if they exist
653       if (arg.default_value()) {
654         auto default_value_hash = c10::hash<c10::IValue>{}(arg.default_value().value());
655         hash = c10::hash_combine(hash, default_value_hash);
656       }
657       if (arg.N()) {
658         auto N_hash = std::hash<int64_t>{}(*arg.N());
659         hash = c10::hash_combine(hash, N_hash);
660       }
661       if (arg.alias_info()) {
662         auto alias_info_hash = std::hash<c10::AliasInfo>{}(*arg.alias_info());
663         hash = c10::hash_combine(hash, alias_info_hash);
664       }
665       return hash;
666     }
667   };
668 template<>
669   struct hash<c10::FunctionSchema> {
670     size_t operator()(const c10::FunctionSchema& schema) const
671     {
672       auto hash = std::hash<c10::OperatorName>{}(schema.operator_name());
673       auto args_hash = c10::hash<std::vector<c10::Argument>>{}(schema.arguments());
674       auto returns_hash = c10::hash<std::vector<c10::Argument>>{}(schema.returns());
675       auto is_vararg_hash = std::hash<bool>{}(schema.is_vararg());
676       auto is_varret_hash = std::hash<bool>{}(schema.is_varret());
677       hash = c10::hash_combine(hash, args_hash);
678       hash = c10::hash_combine(hash, returns_hash);
679       hash = c10::hash_combine(hash, is_vararg_hash);
680       hash = c10::hash_combine(hash, is_varret_hash);
681       return hash;
682     }
683   };
684 } // namespace std
685 
686 
687 #include <ATen/core/function_schema_inl.h>  // IWYU pragma: keep
688