xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/operator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// In-memory description of all ATen ops, similar to the Caffe2 schema.
// Once C10 exists this can be removed or stubbed out, but we need it now
// to implement correct semantic checking for script.
4 #pragma once
5 
6 #include <ATen/core/dispatch/Dispatcher.h>
7 #include <ATen/core/dispatch/OperatorOptions.h>
8 #include <ATen/core/op_registration/op_allowlist.h>
9 #include <ATen/core/stack.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/overloaded.h>
12 #include <torch/csrc/jit/frontend/function_schema_parser.h>
13 #include <torch/csrc/jit/runtime/operator_options.h>
14 #include <torch/library.h>
15 
16 #include <ATen/core/function_schema.h>
17 #include <ATen/core/symbol.h>
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 namespace torch::jit {
29 
30 struct Node;
31 using ::c10::Argument;
32 using ::c10::FunctionSchema;
33 using ::c10::Symbol;
34 
35 using OperationCreator = Operation (*)(const Node*);
36 
37 namespace {
38 const std::array<at::Tag, 1> kJitOnlyOperatorTags = {
39     at::Tag::pt2_compliant_tag};
40 }
41 
/*
 * Note: JIT relies on Operator instances having static lifetime because, for
 * example, it stores a non-owning FunctionSchema* pointer in the Node class
 * that points to the function schema stored in the Operator instance.
 * jit::Operator is also meant to store more operator-related information,
 * such as symbolic derivatives, which likewise requires static lifetime so
 * that changes to symbolic derivatives are remembered.
 *
 * Currently, the JIT operator library contains a jit::Operator instance
 * with a wrapper for each c10 operator. The c10 operator library registers
 * those wrappers using listeners in register_c10_ops.cpp.
 * TODO: Instead of doing it this way, we should only have pure-JIT ops in
 * the JIT library and have the JIT operator lookup also look into the c10
 * library.
 */
57 
58 // An Operator is a thin wrapper around either a pure JIT operator (e.g. prim
59 // ops) or a c10 operator, allowing some common operations and abstracting away
60 // the concrete operator nature.
61 struct TORCH_API Operator {
62  private:
63   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
64   struct C10Operator final {
65     c10::OperatorHandle handle_;
66     Operation op_;
67   };
68   struct UnparsedFunctionSchema final {
69     std::string schema_string_;
70     mutable std::optional<c10::AliasAnalysisKind> alias_analysis_;
71   };
72   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
73   struct JitOnlyOperator final {
74     // The only valid transition for schema_ is from right->left, i.e.
75     // when the schema gets parsed.
76     mutable std::variant<FunctionSchema, UnparsedFunctionSchema> schema_;
77 
78     std::variant<Operation, OperationCreator> op_;
79   };
80 
81  public:
OperatorOperator82   Operator(c10::OperatorHandle opHandle, Operation operation)
83       : op_(C10Operator{std::move(opHandle), std::move(operation)}) {}
84 
OperatorOperator85   Operator(
86       std::string schema,
87       Operation op,
88       c10::AliasAnalysisKind alias_analysis)
89       : op_(JitOnlyOperator{
90             UnparsedFunctionSchema{std::move(schema), alias_analysis},
91             Operation(std::move(op))}) {}
92 
OperatorOperator93   Operator(
94       std::string name,
95       std::string overload_name,
96       std::vector<Argument> arguments,
97       std::vector<Argument> returns,
98       Operation op,
99       c10::AliasAnalysisKind alias_analysis)
100       : op_(JitOnlyOperator{
101             FunctionSchema(varArgSchemaWithName(
102                 std::move(name),
103                 std::move(overload_name),
104                 std::move(arguments),
105                 std::move(returns),
106                 alias_analysis)),
107             std::move(op)}) {}
108 
OperatorOperator109   Operator(
110       std::string schema,
111       OperationCreator op_creator,
112       c10::AliasAnalysisKind alias_analysis)
113       : op_(JitOnlyOperator{
114             UnparsedFunctionSchema{std::move(schema), alias_analysis},
115             op_creator}) {}
116 
117   // Helper constructor to register `op` to run
118   // run for _every_ IR Node where n.kind() == name, regardless of arguments.
119   // This is accomplished by marking the schema varargs and having no required
120   // arguments.
OperatorOperator121   Operator(
122       Symbol name,
123       OperationCreator op_creator,
124       c10::AliasAnalysisKind alias_analysis)
125       : op_(JitOnlyOperator{
126             FunctionSchema(varArgSchemaWithName(name, alias_analysis)),
127             op_creator}) {}
128 
129   Operation getOperation(const Node* node = nullptr) const {
130     return std::visit(
131         c10::overloaded(
132             [](const C10Operator& op) { return op.op_; },
133             [node](const JitOnlyOperator& op) {
134               return std::visit(
135                   c10::overloaded(
136                       [](const Operation& op) { return op; },
137                       [node](const OperationCreator& op_creator) {
138                         return op_creator(node);
139                       }),
140                   op.op_);
141             }),
142         op_);
143   }
144 
getOperationForDispatchKeyOperator145   Operation getOperationForDispatchKey(c10::DispatchKey dk) const {
146     // TODO: some sort of caching mechanism?
147     return std::visit(
148         c10::overloaded(
149             [dk](const C10Operator& op) {
150               return Operation([op, dk](Stack& stack) {
151                 op.handle_.callBoxedForDispatchKey(dk, stack);
152               });
153             },
154             [](const JitOnlyOperator& op) {
155               TORCH_CHECK(
156                   false,
157                   "calling a JIT operator for dispatch key is not supported");
158               return Operation(nullptr);
159             }),
160         op_);
161   }
162 
schemaOperator163   const FunctionSchema& schema() const {
164     return std::visit(
165         c10::overloaded(
166             [](const C10Operator& op) -> const FunctionSchema& {
167               return op.handle_.schema();
168             },
169             [](const JitOnlyOperator& op) -> const FunctionSchema& {
170               // we lazily parse schema initialized from strings so that
171               // we do less work during static operator registration
172               if (op.schema_.index() == 1) {
173                 auto& unmaterializedSchema =
174                     std::get<UnparsedFunctionSchema>(op.schema_);
175                 FunctionSchema schema =
176                     parseSchema(unmaterializedSchema.schema_string_);
177                 if (unmaterializedSchema.alias_analysis_.has_value()) {
178                   // TODO What if it gets set later?
179                   schema.setAliasAnalysis(
180                       *unmaterializedSchema.alias_analysis_);
181                 }
182                 op.schema_ = std::move(schema);
183               }
184               return std::get<FunctionSchema>(op.schema_);
185             }),
186         op_);
187   }
188 
getTagsOperator189   c10::ArrayRef<at::Tag> getTags() const {
190     return std::visit(
191         c10::overloaded(
192             [](const C10Operator& op) { return op.handle_.getTags(); },
193             [](const JitOnlyOperator& op) {
194               // JitOnlyOperators don't have an c10::OperatorHandle or a way to
195               // specify tags. We're grandfathering them all into
196               // pt2_compliant_tag, but for anything else, please just stop
197               // using JitOnlyOperator.
198               return c10::ArrayRef<at::Tag>(kJitOnlyOperatorTags);
199             }),
200         op_);
201   }
202 
isC10OpOperator203   bool isC10Op() const {
204     return op_.index() == 0;
205   }
206 
aliasAnalysisKindOperator207   c10::AliasAnalysisKind aliasAnalysisKind() const {
208     const FunctionSchema& schemaRef = schema();
209     c10::AliasAnalysisKind alias_analysis = schemaRef.aliasAnalysis();
210 
211     TORCH_CHECK(
212         alias_analysis == AliasAnalysisKind::FROM_SCHEMA ||
213             !schemaRef.hasAnyAliasInfo(),
214         "In operator registration: Tried to register operator ",
215         schemaRef,
216         " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");
217     return alias_analysis;
218   }
219 
hasOperationOperator220   bool hasOperation() const {
221     return std::visit(
222         c10::overloaded(
223             [](const C10Operator&) { return true; },
224             [](const JitOnlyOperator& op) { return op.op_.index() == 0; }),
225         op_);
226   }
227 
228  private:
varArgSchemaWithNameOperator229   static FunctionSchema varArgSchemaWithName(
230       Symbol name,
231       AliasAnalysisKind alias_analysis) {
232     auto result = FunctionSchema(
233         name,
234         "",
235         {},
236         {},
237         /*is_vararg*/ true,
238         /*is_varret*/ true);
239     result.setAliasAnalysis(alias_analysis);
240     return result;
241   }
242 
varArgSchemaWithNameOperator243   static FunctionSchema varArgSchemaWithName(
244       std::string name,
245       std::string overload_name,
246       std::vector<Argument> arguments,
247       std::vector<Argument> returns,
248       AliasAnalysisKind alias_analysis) {
249     auto result = FunctionSchema(
250         std::move(name),
251         std::move(overload_name),
252         std::move(arguments),
253         std::move(returns),
254         /*is_vararg*/ false,
255         /*is_varret*/ false);
256     result.setAliasAnalysis(alias_analysis);
257     return result;
258   }
259 
260   std::variant<C10Operator, JitOnlyOperator> op_;
261 };
262 
263 TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
264 
265 TORCH_API const std::vector<std::shared_ptr<Operator>> getAllOperators();
266 TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
267     Symbol name);
268 // Returns operators in the order which OpOverloadPacket resolves them.
269 TORCH_API std::vector<std::shared_ptr<Operator>> getAllSortedOperatorsFor(
270     Symbol name);
271 
272 // given a operator with an overload name, find the specific operator related to
273 // it, may return nullptr if no operator exists.
274 TORCH_API std::shared_ptr<Operator> findOperatorFor(
275     const c10::OperatorName& full_name);
276 
277 TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);
278 
279 TORCH_API void registerOperator(Operator&& op);
280 TORCH_API void deregisterOperator(const FunctionSchema& schema);
281 
282 // XXX: this function is meant to be used with string literals only!
283 TORCH_API std::shared_ptr<Operator> getOperatorForLiteral(
284     const char* signature);
285 
286 // Ensure the thing that registers c10 ops is defined.
287 // Otherwise, our registry will not have c10 ops. You can run into this
288 // scenario if you're querying registered ops during static init.
289 //
290 // This fn is defined in register_c10_ops.cpp
291 TORCH_API void ensure_c10_registerer_defined();
292 
293 // Used to assert that unschematized operators have an analysis method written
294 TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
295 
// Factory functions that generate an optional Operator. The SelectiveStr
// overloads are chosen by the template bool argument: `true` yields an
// Operator and `false` yields std::nullopt. The argument can be a
// compile-time function used for selective op registration based on the
// schema string.
300 template <typename Func>
OperatorGenerator(const char * schema_str,Func && op,AliasAnalysisKind alias_analysis)301 std::optional<Operator> OperatorGenerator(
302     const char* schema_str,
303     Func&& op,
304     AliasAnalysisKind alias_analysis) {
305   return std::optional<Operator>(Operator(
306       std::string(schema_str), std::forward<Func>(op), alias_analysis));
307 }
308 
309 template <typename Func>
OperatorGenerator(torch::detail::SelectiveStr<true> schema_str,Func && op,AliasAnalysisKind alias_analysis)310 std::optional<Operator> OperatorGenerator(
311     torch::detail::SelectiveStr<true> schema_str,
312     Func&& op,
313     AliasAnalysisKind alias_analysis) {
314   return OperatorGenerator(
315       static_cast<const char*>(schema_str),
316       std::forward<Func>(op),
317       alias_analysis);
318 }
319 
320 template <typename Func>
OperatorGenerator(torch::detail::SelectiveStr<false> schema_str,Func && op,AliasAnalysisKind alias_analysis)321 std::optional<Operator> OperatorGenerator(
322     torch::detail::SelectiveStr<false> schema_str,
323     Func&& op,
324     AliasAnalysisKind alias_analysis) {
325   return std::nullopt;
326 }
327 
328 template <typename Func>
OperatorGenerator(const std::string name,const std::string overload_name,const std::vector<c10::Argument> arguments,const std::vector<c10::Argument> returns,Func && op,AliasAnalysisKind alias_analysis)329 std::optional<Operator> OperatorGenerator(
330     const std::string name,
331     const std::string overload_name,
332     const std::vector<c10::Argument> arguments,
333     const std::vector<c10::Argument> returns,
334     Func&& op,
335     AliasAnalysisKind alias_analysis) {
336   return std::optional<Operator>(Operator(
337       name,
338       overload_name,
339       arguments,
340       returns,
341       std::forward<Func>(op),
342       alias_analysis));
343 }
344 
345 } // namespace torch::jit
346