// In-memory description of all ATen ops, similar to the Caffe2 schema.
// Once C10 exists this can be removed or stubbed out, but we need it now
// to implement correct semantic checking for TorchScript.
4 #pragma once
5
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <c10/util/overloaded.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/library.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/symbol.h>

#include <array>
#include <functional>
#include <initializer_list>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
27
28 namespace torch::jit {
29
30 struct Node;
31 using ::c10::Argument;
32 using ::c10::FunctionSchema;
33 using ::c10::Symbol;
34
35 using OperationCreator = Operation (*)(const Node*);
36
37 namespace {
38 const std::array<at::Tag, 1> kJitOnlyOperatorTags = {
39 at::Tag::pt2_compliant_tag};
40 }
41
42 /*
 * Note: JIT relies on Operator instances having static lifetime because, for
 * example, the Node class stores a non-owning FunctionSchema* pointer that
 * points to the function schema stored in the Operator instance.
 * Also, jit::Operator is meant to store more operator-related information,
 * such as symbolic derivatives, which likewise requires static lifetime so
 * that changes to symbolic derivatives are remembered.
49 *
50 * Currently, the JIT operator library contains a jit::Operator instance
51 * with a wrapper for each c10 operator. The c10 operator library registers
52 * those wrappers using listeners in register_c10_ops.cpp.
53 * TODO Instead of doing it this way, we should only have pure-jit ops in
54 * the jit library but have the JIT operator lookup look into the c10 library
55 * too.
56 */
57
58 // An Operator is a thin wrapper around either a pure JIT operator (e.g. prim
59 // ops) or a c10 operator, allowing some common operations and abstracting away
60 // the concrete operator nature.
61 struct TORCH_API Operator {
62 private:
63 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
64 struct C10Operator final {
65 c10::OperatorHandle handle_;
66 Operation op_;
67 };
68 struct UnparsedFunctionSchema final {
69 std::string schema_string_;
70 mutable std::optional<c10::AliasAnalysisKind> alias_analysis_;
71 };
72 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
73 struct JitOnlyOperator final {
74 // The only valid transition for schema_ is from right->left, i.e.
75 // when the schema gets parsed.
76 mutable std::variant<FunctionSchema, UnparsedFunctionSchema> schema_;
77
78 std::variant<Operation, OperationCreator> op_;
79 };
80
81 public:
OperatorOperator82 Operator(c10::OperatorHandle opHandle, Operation operation)
83 : op_(C10Operator{std::move(opHandle), std::move(operation)}) {}
84
OperatorOperator85 Operator(
86 std::string schema,
87 Operation op,
88 c10::AliasAnalysisKind alias_analysis)
89 : op_(JitOnlyOperator{
90 UnparsedFunctionSchema{std::move(schema), alias_analysis},
91 Operation(std::move(op))}) {}
92
OperatorOperator93 Operator(
94 std::string name,
95 std::string overload_name,
96 std::vector<Argument> arguments,
97 std::vector<Argument> returns,
98 Operation op,
99 c10::AliasAnalysisKind alias_analysis)
100 : op_(JitOnlyOperator{
101 FunctionSchema(varArgSchemaWithName(
102 std::move(name),
103 std::move(overload_name),
104 std::move(arguments),
105 std::move(returns),
106 alias_analysis)),
107 std::move(op)}) {}
108
OperatorOperator109 Operator(
110 std::string schema,
111 OperationCreator op_creator,
112 c10::AliasAnalysisKind alias_analysis)
113 : op_(JitOnlyOperator{
114 UnparsedFunctionSchema{std::move(schema), alias_analysis},
115 op_creator}) {}
116
117 // Helper constructor to register `op` to run
118 // run for _every_ IR Node where n.kind() == name, regardless of arguments.
119 // This is accomplished by marking the schema varargs and having no required
120 // arguments.
OperatorOperator121 Operator(
122 Symbol name,
123 OperationCreator op_creator,
124 c10::AliasAnalysisKind alias_analysis)
125 : op_(JitOnlyOperator{
126 FunctionSchema(varArgSchemaWithName(name, alias_analysis)),
127 op_creator}) {}
128
129 Operation getOperation(const Node* node = nullptr) const {
130 return std::visit(
131 c10::overloaded(
132 [](const C10Operator& op) { return op.op_; },
133 [node](const JitOnlyOperator& op) {
134 return std::visit(
135 c10::overloaded(
136 [](const Operation& op) { return op; },
137 [node](const OperationCreator& op_creator) {
138 return op_creator(node);
139 }),
140 op.op_);
141 }),
142 op_);
143 }
144
getOperationForDispatchKeyOperator145 Operation getOperationForDispatchKey(c10::DispatchKey dk) const {
146 // TODO: some sort of caching mechanism?
147 return std::visit(
148 c10::overloaded(
149 [dk](const C10Operator& op) {
150 return Operation([op, dk](Stack& stack) {
151 op.handle_.callBoxedForDispatchKey(dk, stack);
152 });
153 },
154 [](const JitOnlyOperator& op) {
155 TORCH_CHECK(
156 false,
157 "calling a JIT operator for dispatch key is not supported");
158 return Operation(nullptr);
159 }),
160 op_);
161 }
162
schemaOperator163 const FunctionSchema& schema() const {
164 return std::visit(
165 c10::overloaded(
166 [](const C10Operator& op) -> const FunctionSchema& {
167 return op.handle_.schema();
168 },
169 [](const JitOnlyOperator& op) -> const FunctionSchema& {
170 // we lazily parse schema initialized from strings so that
171 // we do less work during static operator registration
172 if (op.schema_.index() == 1) {
173 auto& unmaterializedSchema =
174 std::get<UnparsedFunctionSchema>(op.schema_);
175 FunctionSchema schema =
176 parseSchema(unmaterializedSchema.schema_string_);
177 if (unmaterializedSchema.alias_analysis_.has_value()) {
178 // TODO What if it gets set later?
179 schema.setAliasAnalysis(
180 *unmaterializedSchema.alias_analysis_);
181 }
182 op.schema_ = std::move(schema);
183 }
184 return std::get<FunctionSchema>(op.schema_);
185 }),
186 op_);
187 }
188
getTagsOperator189 c10::ArrayRef<at::Tag> getTags() const {
190 return std::visit(
191 c10::overloaded(
192 [](const C10Operator& op) { return op.handle_.getTags(); },
193 [](const JitOnlyOperator& op) {
194 // JitOnlyOperators don't have an c10::OperatorHandle or a way to
195 // specify tags. We're grandfathering them all into
196 // pt2_compliant_tag, but for anything else, please just stop
197 // using JitOnlyOperator.
198 return c10::ArrayRef<at::Tag>(kJitOnlyOperatorTags);
199 }),
200 op_);
201 }
202
isC10OpOperator203 bool isC10Op() const {
204 return op_.index() == 0;
205 }
206
aliasAnalysisKindOperator207 c10::AliasAnalysisKind aliasAnalysisKind() const {
208 const FunctionSchema& schemaRef = schema();
209 c10::AliasAnalysisKind alias_analysis = schemaRef.aliasAnalysis();
210
211 TORCH_CHECK(
212 alias_analysis == AliasAnalysisKind::FROM_SCHEMA ||
213 !schemaRef.hasAnyAliasInfo(),
214 "In operator registration: Tried to register operator ",
215 schemaRef,
216 " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");
217 return alias_analysis;
218 }
219
hasOperationOperator220 bool hasOperation() const {
221 return std::visit(
222 c10::overloaded(
223 [](const C10Operator&) { return true; },
224 [](const JitOnlyOperator& op) { return op.op_.index() == 0; }),
225 op_);
226 }
227
228 private:
varArgSchemaWithNameOperator229 static FunctionSchema varArgSchemaWithName(
230 Symbol name,
231 AliasAnalysisKind alias_analysis) {
232 auto result = FunctionSchema(
233 name,
234 "",
235 {},
236 {},
237 /*is_vararg*/ true,
238 /*is_varret*/ true);
239 result.setAliasAnalysis(alias_analysis);
240 return result;
241 }
242
varArgSchemaWithNameOperator243 static FunctionSchema varArgSchemaWithName(
244 std::string name,
245 std::string overload_name,
246 std::vector<Argument> arguments,
247 std::vector<Argument> returns,
248 AliasAnalysisKind alias_analysis) {
249 auto result = FunctionSchema(
250 std::move(name),
251 std::move(overload_name),
252 std::move(arguments),
253 std::move(returns),
254 /*is_vararg*/ false,
255 /*is_varret*/ false);
256 result.setAliasAnalysis(alias_analysis);
257 return result;
258 }
259
260 std::variant<C10Operator, JitOnlyOperator> op_;
261 };
262
// Renders `schema` as a string (the exact canonical format is determined by
// the definition, not shown here).
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);

// Returns all registered operators, by value.
// NOTE(review): the `const` on a by-value return prevents callers from moving
// the result; dropping it must be done together with the definition.
TORCH_API const std::vector<std::shared_ptr<Operator>> getAllOperators();
// Returns the operators registered under `name`, by const reference.
TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
    Symbol name);
// Returns operators in the order which OpOverloadPacket resolves them.
TORCH_API std::vector<std::shared_ptr<Operator>> getAllSortedOperatorsFor(
    Symbol name);

// Given an operator name with an overload name, finds the specific operator
// it refers to. May return nullptr if no such operator exists.
TORCH_API std::shared_ptr<Operator> findOperatorFor(
    const c10::OperatorName& full_name);

// Returns operator symbols similar to `input_op` (presumably used for
// "did you mean" style suggestions — TODO confirm against the definition).
TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);

// Adds `op` to the JIT operator registry.
TORCH_API void registerOperator(Operator&& op);
// Removes the operator matching `schema` from the JIT operator registry.
TORCH_API void deregisterOperator(const FunctionSchema& schema);

// XXX: this function is meant to be used with string literals only!
TORCH_API std::shared_ptr<Operator> getOperatorForLiteral(
    const char* signature);

// Ensure the thing that registers c10 ops is defined.
// Otherwise, our registry will not have c10 ops. You can run into this
// scenario if you're querying registered ops during static init.
//
// This fn is defined in register_c10_ops.cpp
TORCH_API void ensure_c10_registerer_defined();

// Used to assert that unschematized operators have an analysis method written
TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
295
// A family of factory functions that generate an optional operator. The
// SelectiveStr<true> / SelectiveStr<false> overloads are picked by that
// template's bool argument, which can be a compile-time function for selective
// op registration based on the schema string; the <false> overload yields
// std::nullopt.
300 template <typename Func>
OperatorGenerator(const char * schema_str,Func && op,AliasAnalysisKind alias_analysis)301 std::optional<Operator> OperatorGenerator(
302 const char* schema_str,
303 Func&& op,
304 AliasAnalysisKind alias_analysis) {
305 return std::optional<Operator>(Operator(
306 std::string(schema_str), std::forward<Func>(op), alias_analysis));
307 }
308
309 template <typename Func>
OperatorGenerator(torch::detail::SelectiveStr<true> schema_str,Func && op,AliasAnalysisKind alias_analysis)310 std::optional<Operator> OperatorGenerator(
311 torch::detail::SelectiveStr<true> schema_str,
312 Func&& op,
313 AliasAnalysisKind alias_analysis) {
314 return OperatorGenerator(
315 static_cast<const char*>(schema_str),
316 std::forward<Func>(op),
317 alias_analysis);
318 }
319
320 template <typename Func>
OperatorGenerator(torch::detail::SelectiveStr<false> schema_str,Func && op,AliasAnalysisKind alias_analysis)321 std::optional<Operator> OperatorGenerator(
322 torch::detail::SelectiveStr<false> schema_str,
323 Func&& op,
324 AliasAnalysisKind alias_analysis) {
325 return std::nullopt;
326 }
327
328 template <typename Func>
OperatorGenerator(const std::string name,const std::string overload_name,const std::vector<c10::Argument> arguments,const std::vector<c10::Argument> returns,Func && op,AliasAnalysisKind alias_analysis)329 std::optional<Operator> OperatorGenerator(
330 const std::string name,
331 const std::string overload_name,
332 const std::vector<c10::Argument> arguments,
333 const std::vector<c10::Argument> returns,
334 Func&& op,
335 AliasAnalysisKind alias_analysis) {
336 return std::optional<Operator>(Operator(
337 name,
338 overload_name,
339 arguments,
340 returns,
341 std::forward<Func>(op),
342 alias_analysis));
343 }
344
345 } // namespace torch::jit
346