/**
 * This file contains some general registration test cases.
 * More detailed test cases containing different APIs for registering kernels
 * are found in other files in this directory.
 */

#include <gtest/gtest.h>

// This file intentionally tests some deprecated APIs
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"

#include <ATen/core/boxing/impl/test_helpers.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include <ATen/core/Tensor.h>
#include <functional>

#include <ATen/core/LegacyTypeDispatch.h>

#include <algorithm>

using c10::RegisterOperators;
using c10::OperatorKernel;
using c10::OperatorHandle;
using c10::Dispatcher;
using c10::IValue;
using c10::DispatchKey;

using torch::Library;
using torch::CppFunction;

using at::Tensor;

namespace {

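// Test kernels used throughout this file: DummyKernel is a no-op, while
// MockKernel sets the bool flag it was constructed with so tests can verify
// which kernel the dispatcher actually invoked.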
struct DummyKernel final : OperatorKernel {
  void operator()(Tensor) {}
};

struct MockKernel final : OperatorKernel {
  MockKernel(bool* called): called_(called) {}

  void operator()(Tensor) {
    *called_ = true;
  }
private:
  bool* called_;
};

TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) {
  bool called = false;
  auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy(Tensor dummy) -> ()").catchAllKernel<MockKernel>(&called));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, whenRegisteringWithSchemaAfterKernelInOptionsObject_thenCanBeCalled) {
  bool called = false;
  auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called).schema("_test::dummy(Tensor dummy) -> ()"));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, whenRegisteringWithNameBeforeKernelInOptionsObject_thenCanBeCalled) {
  bool called = false;
  auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy").catchAllKernel<MockKernel>(&called));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, whenRegisteringWithNameAfterKernelInOptionsObject_thenCanBeCalled) {
  bool called = false;
  auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called).schema("_test::dummy"));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, whenRegisteringWithoutSchema_thenFails) {
  expectThrows<c10::Error>([] {
    c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel<DummyKernel>());
  }, "In operator registration: Tried to register an operator without specifying a schema or operator name.");
}

TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::DispatchKey::CPU));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  expectThrows<c10::Error>([&] {
    callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
  " backend.");
}

TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCallsCatchallKernel) {
  bool called = false;
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called);
}

// TODO Rewrite (since this is now allowed) and reenable
// TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) {
//   bool called = false;
//   auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
//   expectThrows<c10::Error>([&] {
//     c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
//   }, "for an operator which already has a catch-all kernel registered");
// }

// TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernelInSameOpCall_thenFails) {
//   bool called = false;
//   expectThrows<c10::Error>([&] {
//     auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
//       .catchAllKernel<MockKernel>(&called)
//       .kernel<MockKernel>(c10::DispatchKey::CPU, &called));
//   }, "for an operator which already has a catch-all kernel registered");
// }

TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegisteringCatchallKernelAndCallingOp_thenCallsCatchallKernel) {
  bool called = false;
  {
    auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
  }

  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called);
}

// TODO Rewrite (since this is now allowed) and reenable
// TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) {
//   bool called = false;
//   auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
//   expectThrows<c10::Error>([&] {
//     c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
//   }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPU. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
// }
//
// TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernelInSameOpCall_thenFails) {
//   bool called = false;
//   expectThrows<c10::Error>([&] {
//     auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
//       .kernel<MockKernel>(c10::DispatchKey::CPU, &called)
//       .catchAllKernel<MockKernel>(&called));
//   }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPU. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
// }

TEST(OperatorRegistrationTest, givenOpWithCatchallKernelOutOfScope_whenRegisteringDispatchedKernelAndCallingOp_thenCallsCatchallKernel) {
  bool called = false;
  {
    auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
  }

  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  EXPECT_FALSE(called);
  callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithSchema_thenOnlyRegistersSchema) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered
  expectThrows<c10::Error>([&] {
    callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  }, "Could not run '_test::dummy' with arguments from the 'CPU'"
  " backend.");
}

TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithoutSchema_thenFails) {
  expectThrows<c10::Error>([&] {
    c10::RegisterOperators().op("_test::dummy");
  }, "Cannot infer operator schema in registration of operator _test::dummy because there is no kernel specified.");
}

TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenSchemaIsGone) {
  {
    auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
  }

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  EXPECT_FALSE(op.has_value());
}

TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegistering_thenRegisters) {
  // as long as we don't register non-catchall kernels, ops without tensor arguments are fine
  auto registrar = c10::RegisterOperators().op("_test::dummy() -> ()");

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered
}

TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) {
  expectThrows<c10::Error>([&] {
    auto registrar = c10::RegisterOperators()
        .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
            .kernel<DummyKernel>(c10::DispatchKey::CPU)
            .kernel<DummyKernel>(c10::DispatchKey::CPU));
  }, "In operator registration: Tried to register multiple kernels with same dispatch key CPU for operator schema _test::dummy");
}

TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) {
  expectThrows<c10::Error>([&] {
    auto registrar = c10::RegisterOperators()
        .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
            .catchAllKernel<DummyKernel>()
            .catchAllKernel<DummyKernel>());
  }, "Tried to register multiple catch-all kernels for operator schema _test::dummy");
}

TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnboxedWithCPUDispatchKey) {
  bool called_kernel_cpu = false;
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<MockKernel>(c10::DispatchKey::CPU, &called_kernel_cpu));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered

  // Ensure that the dispatcher takes the dispatch key from the direct argument instead of from the tensor.
  called_kernel_cpu = false;
  callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called_kernel_cpu);

  // Ensure that the dispatch key from the tensor is not used here.
  called_kernel_cpu = false;
  expectThrows<c10::Error>([&] {
    callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
  }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
  " backend.");
}

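// Helper that builds the dispatcher's "Could not run ..." error prefix for a
// given backend dispatch key, so the tests below can assert on error messages.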
std::string expectedMessageForBackend(DispatchKey key) {
  std::string key_str(c10::toString(key));
  return "Could not run '_test::dummy' with arguments from the '" + key_str + "' backend";
}

TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCalling_thenCallsCorrectKernel) {
  bool called_kernel1 = false;
  bool called_kernel2 = false;
  auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<MockKernel>(c10::DispatchKey::CPU, &called_kernel1)
    .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel2));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered

  called_kernel1 = called_kernel2 = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  EXPECT_TRUE(called_kernel1);
  EXPECT_FALSE(called_kernel2);

  called_kernel1 = called_kernel2 = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_FALSE(called_kernel1);
  EXPECT_TRUE(called_kernel2);

  // Test for out-of-tree lazy backends: the ::Lazy key is now registered to the TS backend in-tree.
  for (c10::DispatchKey key : {c10::DispatchKey::XLA}) {
    std::string expectMessage = expectedMessageForBackend(key);
    expectThrows<c10::Error>([&] {
      callOp(*op, dummyTensor(key));
    }, expectMessage.c_str());

    // also assert that the error message contains the available tensor type ids, but don't assert their order
    expectThrows<c10::Error>([&] {
      callOp(*op, dummyTensor(key));
    }, "CPU");
    expectThrows<c10::Error>([&] {
      callOp(*op, dummyTensor(key));
    }, "CUDA");
  }
}

bool called_stackbased_kernel = false;
void stackBasedKernel(const OperatorHandle&, c10::Stack* stack) {
  called_stackbased_kernel = true;
}

TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndNoneCanInferSchema_thenFails) {
  expectThrows<c10::Error>([&] {
    auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
      .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
      .kernel<&stackBasedKernel>(c10::DispatchKey::CUDA)
      .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
      .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));
  }, "Cannot infer operator schema for this kind of kernel in registration of operator _test::dummy");
}

TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndNoneCanInferSchema_thenSucceeds) {
  bool called_kernel = false;
  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
    .kernel<&stackBasedKernel>(c10::DispatchKey::CUDA)
    .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
    .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered

  called_kernel = called_stackbased_kernel = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  EXPECT_TRUE(called_stackbased_kernel);
  EXPECT_FALSE(called_kernel);

  called_kernel = called_stackbased_kernel = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_TRUE(called_stackbased_kernel);
  EXPECT_FALSE(called_kernel);

  for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
    called_kernel = called_stackbased_kernel = false;
    callOp(*op, dummyTensor(key));
    EXPECT_TRUE(called_stackbased_kernel);
    EXPECT_FALSE(called_kernel);
  }
}

TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndOnlyOneCanInferSchema_thenSucceeds) {
  bool called_kernel = false;
  auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
    .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
    .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel)
    .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
    .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered

  called_kernel = called_stackbased_kernel = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  EXPECT_TRUE(called_stackbased_kernel);
  EXPECT_FALSE(called_kernel);

  called_kernel = called_stackbased_kernel = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_FALSE(called_stackbased_kernel);
  EXPECT_TRUE(called_kernel);

  for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
    called_kernel = called_stackbased_kernel = false;
    callOp(*op, dummyTensor(key));
    EXPECT_TRUE(called_stackbased_kernel);
    EXPECT_FALSE(called_kernel);
  }
}

TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndOnlyOneCanInferSchema_thenSucceeds) {
  bool called_kernel = false;
  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
    .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel)
    .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
    .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value()); // assert schema is registered

  called_kernel = called_stackbased_kernel = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  EXPECT_TRUE(called_stackbased_kernel);
  EXPECT_FALSE(called_kernel);

  called_kernel = called_stackbased_kernel = false;
  callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  EXPECT_FALSE(called_stackbased_kernel);
  EXPECT_TRUE(called_kernel);

  for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
    called_kernel = called_stackbased_kernel = false;
    callOp(*op, dummyTensor(key));
    EXPECT_TRUE(called_stackbased_kernel);
    EXPECT_FALSE(called_kernel);
  }
}

struct DummyKernelWithIntParam final : OperatorKernel {
  void operator()(Tensor, int64_t) {}
};

TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_thenFails) {
  bool called_kernel = false;
  expectThrows<c10::Error>([&] {
    auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
      .kernel<DummyKernelWithIntParam>(c10::DispatchKey::CPU)
      .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel));
  }, "Mismatch in kernel C++ signatures");
}

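// Boxed backend-fallback kernel: appends the operator name to the string
// argument at stack slot 1 so that tests can observe that the fallback ran.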
void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) {
  (*stack)[1] = (*stack)[1].toStringRef() + op.schema().name();
}

TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernel_thenCanBeCalled) {
  auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");

  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
  EXPECT_EQ("hello _test::dummy", stack[1].toStringRef());
}

TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelForWrongBackend_thenCannotBeCalled) {
  auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CUDA, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");

  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());
  expectThrows<c10::Error>([&] {
    auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
  }, "Could not run '_test::dummy' with arguments from the 'CPU' backend.");
}

bool called = false;

TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForDifferentBackend_thenRegularKernelCanBeCalled) {
  auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");

  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
      .kernel(c10::DispatchKey::CUDA, [] (Tensor, std::string) {
        called = true;
      }));
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  called = false;
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CUDA), "hello ");
  EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForDifferentBackend_thenFallbackKernelCanBeCalled) {
  auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");

  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
      .kernel(c10::DispatchKey::CUDA, [] (Tensor, std::string) {
        called = true;
      }));
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  called = false;
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
  EXPECT_FALSE(called);
  EXPECT_EQ("hello _test::dummy", stack[1].toStringRef());
}

TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForSameBackend_thenCallsRegularKernel) {
  auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");

  auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
      .kernel(c10::DispatchKey::CPU, [] (Tensor, std::string) {
        called = true;
      }));
  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  called = false;
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
  EXPECT_TRUE(called);
}

bool called_autograd = false;
bool called_nonautograd = false;

void nonautograd_kernel(Tensor a) {
  called_nonautograd = true;
}

void autograd_kernel(Tensor a) {
  called_autograd = true;
}

TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  called_autograd = false;
  expectThrows<c10::Error>([&] {
    callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  }, "Could not run '_test::dummy' with arguments from the 'CPU'"
  " backend.");

  op->typed<void(Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
  EXPECT_TRUE(called_autograd);
}

TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<decltype(nonautograd_kernel), nonautograd_kernel>(DispatchKey::CPU)
    .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  called_nonautograd = called_autograd = false;
  op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
  EXPECT_FALSE(called_nonautograd);
  EXPECT_TRUE(called_autograd);
}

TEST(
    OperatorRegistrationTest,
    whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) {
  auto registrar = c10::RegisterOperators().op(
      "_test::dummy(Tensor dummy) -> ()",
      c10::RegisterOperators::options()
          .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
          .kernel<decltype(autograd_kernel), &autograd_kernel>(
              DispatchKey::Autograd));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  // catchAll now maps to CompositeImplicitAutograd which has
  // higher precedence than Autograd
  called_nonautograd = called_autograd = false;
  op->typed<void(Tensor)>().call(
      dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
  EXPECT_TRUE(called_nonautograd);
  EXPECT_FALSE(called_autograd);

  called_nonautograd = called_autograd = false;
  op->typed<void(Tensor)>().call(dummyTensor(DispatchKey::CPU));
  EXPECT_TRUE(called_nonautograd);
  EXPECT_FALSE(called_autograd);
}

TEST(OperatorRegistrationTest, AutogradBackendOverridesAutogradKernel) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(DispatchKey::AutogradCPU)
    .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  expectThrows<c10::Error>([&] {
    callOp(*op, dummyTensor(c10::DispatchKey::CPU));
  }, "Could not run '_test::dummy' with arguments from the 'CPU'"
  " backend.");

  expectThrows<c10::Error>([&] {
    callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
  }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
  " backend.");

  called_nonautograd = called_autograd = false;
  op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
  EXPECT_TRUE(called_nonautograd);
  EXPECT_FALSE(called_autograd);

  called_nonautograd = called_autograd = false;
  op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CUDA, /*requires_grad=*/true));
  EXPECT_TRUE(called_autograd);
  EXPECT_FALSE(called_nonautograd);
}

void LazyBackendsAutogradOverridesAutogradKernel(DispatchKey key) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(toBackendComponent(key)))
    .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
  ASSERT_TRUE(op.has_value());

  std::string expectedMessage = expectedMessageForBackend(key);
  expectThrows<c10::Error>([&] {
    callOp(*op, dummyTensor(key));
  }, expectedMessage.c_str());

  called_nonautograd = called_autograd = false;
  op->typed<void (Tensor)>().call(dummyTensor(key, /*requires_grad=*/true));
  EXPECT_TRUE(called_nonautograd);
  EXPECT_FALSE(called_autograd);

  called_nonautograd = called_autograd = false;
  op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
  EXPECT_TRUE(called_autograd);
  EXPECT_FALSE(called_nonautograd);
}

// We no longer test the ::Lazy key here: it is now registered to the TS backend
// in-tree and thus behaves differently, so it does not throw the expected
// 'Could not run ...' messages.
TEST(OperatorRegistrationTest, AutogradXLAOverridesAutogradKernel) {
  LazyBackendsAutogradOverridesAutogradKernel(DispatchKey::XLA);
}

void whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey key) {
  {
    auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
      .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>());

    auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
    ASSERT_TRUE(op.has_value());

    called_nonautograd = called_autograd = false;
    op->typed<void (Tensor)>().call(dummyTensor(key, /*requires_grad=*/true));
    EXPECT_TRUE(called_nonautograd);
    EXPECT_FALSE(called_autograd);

    called_nonautograd = called_autograd = false;
    op->typed<void (Tensor)>().call(dummyTensor(key));
    EXPECT_FALSE(called_autograd);
    EXPECT_TRUE(called_nonautograd);
  }
  {
    auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
      .kernel<decltype(autograd_kernel), &autograd_kernel>(key)
      .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>());

    auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
    ASSERT_TRUE(op.has_value());

    // When there's a direct registration to the XLA / Lazy backend, Autograd{XLA, Lazy} doesn't pick up the
    // catchAll kernel in precompute but just keeps the fallthrough kernel from the backend fallback.
    // Thus dispatch falls through Autograd{XLA, Lazy} and reaches the kernel at the XLA / Lazy key.
    called_nonautograd = called_autograd = false;
    op->typed<void (Tensor)>().call(dummyTensor(key, /*requires_grad=*/true));
    EXPECT_FALSE(called_nonautograd);
    EXPECT_TRUE(called_autograd);

    called_nonautograd = called_autograd = false;
    op->typed<void (Tensor)>().call(dummyTensor(key));
    EXPECT_TRUE(called_autograd);
    EXPECT_FALSE(called_nonautograd);
  }
}

TEST(OperatorRegistrationTest, whenRegisterWithXLAKernelAndCatchAll_AutogradXLAIsNotFilled) {
  whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey::XLA);
}

TEST(OperatorRegistrationTest, whenRegisterWithLazyKernelAndCatchAll_AutogradLazyIsNotFilled) {
  whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey::Lazy);
}

TEST(OperatorRegistrationTest, whenregisteringwithinvalidoverloadname) {
  expectThrows<c10::Error>([] {
    auto registrar = c10::RegisterOperators().op("_test::dummy.default", c10::RegisterOperators::options()
      .kernel(DispatchKey::CPU, [] (const int64_t&) {}));
  }, "default is not a legal overload name for aten operators");
  expectThrows<c10::Error>([] {
    auto registrar = c10::RegisterOperators().op("_test::dummy.__name__", c10::RegisterOperators::options()
      .kernel(DispatchKey::CPU, [] (const int64_t&) {}));
  }, "__name__ is not a legal overload name for aten operators");
}

TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingCppSignatures_thenFails) {
  expectThrows<c10::Error>([] {
    auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
      .kernel(DispatchKey::CPU, [] (const int64_t&) {})
      .kernel(DispatchKey::CUDA, [] (int64_t) {}));
  }, "Mismatch in kernel C++ signatures");
}

TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) {
  expectThrows<c10::Error>([] {
    auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
      .kernel(DispatchKey::CPU, [] (const int64_t&) {})
      .catchAllKernel([] (int64_t) {}));
  }, "Mismatch in kernel C++ signatures");
}

TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) {
  expectThrows<c10::Error>([] {
    auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
      .catchAllKernel([] (const int64_t&) {})
      .kernel(DispatchKey::CPU, [] (int64_t) {}));
  }, "Mismatch in kernel C++ signatures");
}

TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) {
  auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
    .kernel(DispatchKey::CPU, [] (int64_t) {}));
  expectThrows<c10::Error>([] {
    c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
      .typed<void(const int64_t&)>();
  }, "Tried to access or call an operator with a wrong signature.\n  operator: _test::dummy(int _0) -> ()");
}

TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) {
  auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
    .catchAllKernel([] (int64_t) {}));
  expectThrows<c10::Error>([] {
    c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
      .typed<void(const int64_t&)>();
  }, "Tried to access or call an operator with a wrong signature.\n  operator: _test::dummy(int _0) -> ()");
}

TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingCppSignatures_thenFails) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("dummy(int a) -> ()");
  m.impl("dummy", DispatchKey::CPU, [] (int64_t) {});
  expectThrows<c10::Error>([&] {
    m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {});
  }, "Mismatch in kernel C++ signatures");
}

TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("dummy(int a) -> ()");
  m.impl("dummy", DispatchKey::CPU, [] (int64_t) {});
  expectThrows<c10::Error>([] {
    c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
      .typed<void(const int64_t&)>();
  }, "Tried to access or call an operator with a wrong signature.\n  operator: _test::dummy(int a) -> ()");
}

TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("dummy(int a) -> ()", [] (int64_t) {});
  expectThrows<c10::Error>([] {
    c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
      .typed<void(const int64_t&)>();
  }, "Tried to access or call an operator with a wrong signature.\n  operator: _test::dummy(int a) -> ()");
}

/**
 * This is used to check that a given type works correctly when passed as input
 * to or as output from a kernel.
 *
 * Call ArgTypeTestKernel<Input, Output>::test(api, input, inputExpectation, output, outputExpectation, schema),
 * where `api` is one of TestModernAPI(), TestLegacyAPI() or TestModernAndLegacyAPI(),
 * to test that a kernel with `Input` as input type and `Output` as output type,
 * when called with `input`, fulfills `inputExpectation` inside the kernel, then
 * returns `output` and the returned value fulfills `outputExpectation`.
 *
 * `inputExpectation` and `outputExpectation` should be lambdas that run
 * googletest expect macros (or use other ways to assert the expectation is met).
 *
 * Optionally, you can specify the argument list part of a function schema
 * (e.g. "(Tensor a) -> Tensor") as an additional argument to use when
 * registering the kernel. In this case, the operator registration logic will
 * check that the kernel function signature matches the one you specified.
 */
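// A minimal usage sketch (mirroring the int64_t case exercised in
// testAvailableArgTypes below); the first argument selects which of the
// registration APIs, declared right below, is exercised:
//
//   ArgTypeTestKernel<int64_t>::test(
//     TestModernAPI(),
//     1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
//     2, [] (const c10::Stack& stack) {EXPECT_EQ(2, stack[0].toInt());},
//     "(int a) -> int");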
struct TestModernAPI final {};
struct TestLegacyAPI final {};
struct TestModernAndLegacyAPI final {};

template<class InputType, class OutputType = InputType>
struct ArgTypeTestKernel final : OperatorKernel {
  explicit ArgTypeTestKernel(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output)
  : input_(std::move(input)), inputExpectation_(std::move(inputExpectation)), output_(std::move(output)) {}

  OutputType operator()(InputType input) const {
    inputExpectation_(std::move(input));
    return output_;
  }

  static void test(TestModernAndLegacyAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    test(TestModernAPI(), input, inputExpectation, output, outputExpectation, schema);
    test(TestLegacyAPI(), input, inputExpectation, output, outputExpectation, schema);
  }

  static void test(TestModernAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    return test_([&] {
      return c10::RegisterOperators().op("_test::my_op" + schema, c10::RegisterOperators::options().catchAllKernel<ArgTypeTestKernel>(input, inputExpectation, output));
    }, input, inputExpectation, output, outputExpectation, schema);
  }

  static void test(TestLegacyAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    return test_([&] {
      return c10::RegisterOperators().op("_test::my_op" + schema, [=] (InputType input) -> OutputType {
        inputExpectation(std::move(input));
        return output;
      });
    }, input, inputExpectation, output, outputExpectation, schema);
  }

private:
  static void test_(std::function<c10::RegisterOperators()> registration, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    auto registry = registration();
    auto op = Dispatcher::singleton().findSchema({"_test::my_op", ""});
    ASSERT_TRUE(op.has_value()); // assert schema is registered
    auto actualOutput = callOp(*op, input);
    outputExpectation(actualOutput);
  }

  InputType input_;
  std::function<void(const InputType&)> inputExpectation_;
  OutputType output_;
  std::string schema_;
};

template<class InputType, class OutputType = InputType>
struct testArgTypes final {
  template<class APIType = TestModernAndLegacyAPI>
  static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema) {
    // Test with explicitly specified schema
    ArgTypeTestKernel<InputType, OutputType>::test(
      APIType(), input, inputExpectation, output, [&] (const c10::Stack& output) {
        EXPECT_EQ(1, output.size());
        outputExpectation(output[0]);
      }, schema
    );

    // Test with inferred schema
    ArgTypeTestKernel<InputType, OutputType>::test(
      APIType(), input, inputExpectation, output, [&] (const c10::Stack& output) {
        EXPECT_EQ(1, output.size());
        outputExpectation(output[0]);
      }, ""
    );

    // Test taking argument and returning nothing
    ArgTypeTestKernel<InputType, std::tuple<>>::test(
      APIType(), input, inputExpectation, {}, [] (const c10::Stack&) {}, ""
    );

    // Test taking argument and returning multiple outputs
    ArgTypeTestKernel<InputType, std::tuple<int64_t, OutputType>>::test(
      APIType(), input, inputExpectation, std::tuple<int64_t, OutputType>{3, output}, [&] (const c10::Stack& output) {
        EXPECT_EQ(2, output.size());
        EXPECT_EQ(3, output[0].toInt());
        outputExpectation(output[1]);
      }, ""
    );
  }
};

TEST(OperatorRegistrationTest, testAvailableArgTypes) {
  // TODO Test Scalar

  // primitive types
  testArgTypes<double>::test(
    1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
    2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
    "(float a) -> float");
  testArgTypes<int64_t>::test(
    1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
    2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
    "(int a) -> int");
  testArgTypes<bool>::test(
    true, [] (const bool& v) {EXPECT_EQ(true, v);},
    false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
    "(bool a) -> bool");
  testArgTypes<bool>::test(
    false, [] (const bool& v) {EXPECT_EQ(false, v);},
    true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
    "(bool a) -> bool");
  testArgTypes<std::string>::test(
    "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
    "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toStringRef());},
    "(str a) -> str");
  testArgTypes<Tensor>::test(
    dummyTensor(c10::DispatchKey::CPU), [] (const Tensor& v) {EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v));},
    dummyTensor(c10::DispatchKey::CUDA), [] (const IValue& v) {EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.toTensor()));},
    "(Tensor a) -> Tensor");


  // optional types (with has_value() == true)
  testArgTypes<std::optional<double>>::test(
    std::optional<double>(1.5), [] (const std::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
    std::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
    "(float? a) -> float?");
  testArgTypes<std::optional<int64_t>>::test(
    std::optional<int64_t>(1), [] (const std::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
    std::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
    "(int? a) -> int?");
  testArgTypes<std::optional<bool>>::test(
    std::optional<bool>(true), [] (const std::optional<bool>& v) {EXPECT_EQ(true, v.value());},
    std::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
    "(bool? a) -> bool?");
  testArgTypes<std::optional<bool>>::test(
    std::optional<bool>(false), [] (const std::optional<bool>& v) {EXPECT_EQ(false, v.value());},
    std::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
    "(bool? a) -> bool?");
  testArgTypes<std::optional<std::string>>::test(
    std::optional<std::string>("string1"), [] (const std::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
    std::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toStringRef());},
    "(str? a) -> str?");
  testArgTypes<std::optional<Tensor>>::test(
    std::optional<Tensor>(dummyTensor(c10::DispatchKey::CPU)), [] (const std::optional<Tensor>& v) {EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.value()));},
    std::optional<Tensor>(dummyTensor(c10::DispatchKey::CUDA)), [] (const IValue& v) {EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.toTensor()));},
    "(Tensor? a) -> Tensor?");


  // optional types (with has_value() == false)
  testArgTypes<std::optional<double>>::test(
    std::optional<double>(std::nullopt), [] (const std::optional<double>& v) {EXPECT_FALSE(v.has_value());},
    std::optional<double>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
    "(float? a) -> float?");
  testArgTypes<std::optional<int64_t>>::test(
    std::optional<int64_t>(std::nullopt), [] (const std::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
    std::optional<int64_t>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
    "(int? a) -> int?");
  testArgTypes<std::optional<bool>>::test(
    std::optional<bool>(std::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
    std::optional<bool>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
    "(bool? a) -> bool?");
  testArgTypes<std::optional<bool>>::test(
    std::optional<bool>(std::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
    std::optional<bool>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
    "(bool? a) -> bool?");
  testArgTypes<std::optional<std::string>>::test(
    std::optional<std::string>(std::nullopt), [] (const std::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
    std::optional<std::string>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
    "(str? a) -> str?");
  testArgTypes<std::optional<Tensor>>::test(
    std::optional<Tensor>(std::nullopt), [] (const std::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
    std::optional<Tensor>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
    "(Tensor? a) -> Tensor?");


  // list types (with empty list)
  testArgTypes<c10::List<double>>::test(
    c10::List<double>(), [] (const c10::List<double>& v) {EXPECT_EQ(0, v.size());},
    c10::List<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<double>>().size());},
    "(float[] a) -> float[]");
  testArgTypes<c10::List<int64_t>>::test(
    c10::List<int64_t>(), [] (const c10::List<int64_t>& v) {EXPECT_EQ(0, v.size());},
    c10::List<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
    "(int[] a) -> int[]");
  testArgTypes<c10::List<bool>>::test(
    c10::List<bool>(), [] (const c10::List<bool>& v) {EXPECT_EQ(0, v.size());},
    c10::List<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<bool>>().size());},
    "(bool[] a) -> bool[]");
  testArgTypes<c10::List<std::string>>::test(
    c10::List<std::string>(), [] (const c10::List<std::string>& v) {EXPECT_EQ(0, v.size());},
    c10::List<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
    "(str[] a) -> str[]");


  // list types (with non-empty list)
  testArgTypes<c10::List<double>>::test(
    c10::List<double>({1.5, 2.5}), [] (const c10::List<double>& v) {expectListEquals({1.5, 2.5}, v);},
    c10::List<double>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<c10::List<double>>());},
    "(float[] a) -> float[]");
  testArgTypes<c10::List<int64_t>>::test(
    c10::List<int64_t>({1, 2}), [] (const c10::List<int64_t>& v) {expectListEquals({1, 2}, v);},
    c10::List<int64_t>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
    "(int[] a) -> int[]");
  testArgTypes<c10::List<bool>>::test(
    c10::List<bool>({true, false}), [] (const c10::List<bool>& v) {expectListEquals({true, false}, v);},
    c10::List<bool>({true, false}), [] (const IValue& v) {expectListEquals({true, false}, v.to<c10::List<bool>>());},
    "(bool[] a) -> bool[]");
  testArgTypes<c10::List<std::string>>::test(
    c10::List<std::string>({"first", "second"}), [] (const c10::List<std::string>& v) {expectListEquals({"first", "second"}, v);},
    c10::List<std::string>({"first", "second"}), [] (const IValue& v) {
      EXPECT_EQ(2, v.toListRef().size());
      EXPECT_EQ("first", v.toListRef()[0].toStringRef());
      EXPECT_EQ("second", v.toListRef()[1].toStringRef());
    },
    "(str[] a) -> str[]");
  testArgTypes<c10::List<Tensor>>::test(
    c10::List<Tensor>({dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)}), [] (const c10::List<Tensor>& v) {
      EXPECT_EQ(2, v.size());
      EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.get(0)));
      EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.get(1)));
    },
    c10::List<Tensor>({dummyTensor(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)}), [] (const IValue& v) {
      EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
      EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
      EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
    },
    "(Tensor[] a) -> Tensor[]");

  // ArrayRef list types (with empty list)
  testArgTypes<c10::ArrayRef<double>, c10::List<double>>::test(
    c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
    c10::List<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<double>>().size());},
    "(float[] a) -> float[]");
  testArgTypes<c10::ArrayRef<int64_t>, c10::List<int64_t>>::test(
    c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
    c10::List<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
    "(int[] a) -> int[]");
  testArgTypes<c10::ArrayRef<std::string>, c10::List<std::string>>::test(
    c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
    c10::List<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
    "(str[] a) -> str[]");


  // ArrayRef list types (with non-empty list)
1006   testArgTypes<c10::ArrayRef<double>, c10::List<double>>::test(
1007     c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {expectListEquals({1.5, 2.5}, v);},
1008     c10::List<double>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<c10::List<double>>());},
1009     "(float[] a) -> float[]");
1010   testArgTypes<c10::ArrayRef<int64_t>, c10::List<int64_t>>::test(
1011     c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {expectListEquals({1, 2}, v);},
1012     c10::List<int64_t>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
1013     "(int[] a) -> int[]");
1014   testArgTypes<c10::ArrayRef<std::string>, c10::List<std::string>>::test(
1015     c10::ArrayRef<std::string>({"first", "second"}), [] (c10::ArrayRef<std::string> v) {expectListEquals({"first", "second"}, v);},
1016     c10::List<std::string>({"first", "second"}), [] (const IValue& v) {
1017       EXPECT_EQ(2, v.toListRef().size());
1018       EXPECT_EQ("first", v.toListRef()[0].toStringRef());
1019       EXPECT_EQ("second", v.toListRef()[1].toStringRef());
1020     },
1021     "(str[] a) -> str[]");
1022   testArgTypes<c10::ArrayRef<Tensor>, c10::List<Tensor>>::test(
1023     c10::ArrayRef<Tensor>({dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)}), [] (c10::ArrayRef<Tensor> v) {
1024       EXPECT_EQ(2, v.size());
1025       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v[0]));
1026       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v[1]));
1027     },
1028     c10::List<Tensor>({dummyTensor(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)}), [] (const IValue& v) {
1029       EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
1030       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
1031       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
1032     },
1033     "(Tensor[] a) -> Tensor[]");
1034 
1035 
1036   // std::array list types (with empty list)
1037   testArgTypes<std::array<double, 0>>::test(
1038     std::array<double, 0>(), [] (std::array<double, 0> v) {},
1039     std::array<double, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<double>>().size()));},
1040     "(float[0] a) -> float[0]");
1041   testArgTypes<std::array<int64_t, 0>>::test(
1042     std::array<int64_t, 0>(), [] (std::array<int64_t, 0> v) {},
1043     std::array<int64_t, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<int64_t>>().size()));},
1044     "(int[0] a) -> int[0]");
1045   testArgTypes<std::array<bool, 0>>::test(
1046     std::array<bool, 0>(), [] (std::array<bool, 0> v) {},
1047     std::array<bool, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<std::array<bool, 0>>().size()));},
1048     "(bool[0] a) -> bool[0]");
1049   testArgTypes<std::array<std::string, 0>>::test(
1050     std::array<std::string, 0>(), [] (std::array<std::string, 0> v) {EXPECT_EQ(0, v.size());},
1051     std::array<std::string, 0>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
1052     "(str[0] a) -> str[0]");
1053 
1054 
1055   // std::array list types (with non-empty list)
1056   testArgTypes<std::array<double, 2>>::test(
1057     std::array<double, 2>({1.5, 2.5}), [] (std::array<double, 2> v) {expectListEquals({1.5, 2.5}, v);},
1058     std::array<double, 2>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<std::array<double, 2>>());},
1059     "(float[2] a) -> float[2]");
1060   testArgTypes<std::array<int64_t, 2>>::test(
1061     std::array<int64_t, 2>({1, 2}), [] (std::array<int64_t, 2> v) {expectListEquals({1, 2}, v);},
1062     std::array<int64_t, 2>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<std::array<int64_t, 2>>());},
1063     "(int[2] a) -> int[2]");
1064   testArgTypes<std::array<bool, 2>>::test(
1065     std::array<bool, 2>({true, false}), [] (std::array<bool, 2> v) {expectListEquals({true, false}, v);},
1066     std::array<bool, 2>({true, false}), [] (const IValue& v) {expectListEquals({true, false}, v.to<std::array<bool, 2>>());},
1067     "(bool[2] a) -> bool[2]");
1068   testArgTypes<std::array<std::string, 2>>::test(
1069     std::array<std::string, 2>({"first", "second"}), [] (std::array<std::string, 2> v) {expectListEquals({"first", "second"}, v);},
1070     std::array<std::string, 2>({"first", "second"}), [] (const IValue& v) {
1071       EXPECT_EQ(2, v.toListRef().size());
1072       EXPECT_EQ("first", v.toListRef()[0].toStringRef());
1073       EXPECT_EQ("second", v.toListRef()[1].toStringRef());
1074     },
1075     "(str[2] a) -> str[2]");
1076   testArgTypes<std::array<Tensor, 2>>::test(
1077     std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)}), [] (std::array<Tensor, 2> v) {
1078       EXPECT_EQ(2, v.size());
1079       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v[0]));
1080       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v[1]));
1081     },
1082     std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)}), [] (const IValue& v) {
1083       EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
1084       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
1085       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
1086     },
1087     "(Tensor[2] a) -> Tensor[2]");
1088 
1089 
1090   // deprecated list types (with empty list)
1091   testArgTypes<std::vector<double>>::test<TestLegacyAPI>(
1092     std::vector<double>(), [] (const std::vector<double>& v) {EXPECT_EQ(0, v.size());},
1093     std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<double>>().size());},
1094     "(float[] a) -> float[]");
1095   testArgTypes<std::vector<int64_t>>::test<TestLegacyAPI>(
1096     std::vector<int64_t>(), [] (const std::vector<int64_t>& v) {EXPECT_EQ(0, v.size());},
1097     std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
1098     "(int[] a) -> int[]");
1099   // Note: std::vector<bool> is not supported; use c10::List<bool> instead.
1100   testArgTypes<std::vector<std::string>>::test<TestLegacyAPI>(
1101     std::vector<std::string>(), [] (const std::vector<std::string>& v) {EXPECT_EQ(0, v.size());},
1102     std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
1103     "(str[] a) -> str[]");
1104 
1105 
1106   // deprecated list types (with non-empty list)
1107   testArgTypes<std::vector<double>>::test<TestLegacyAPI>(
1108     std::vector<double>({1.5, 2.5}), [] (const std::vector<double>& v) {expectListEquals({1.5, 2.5}, v);},
1109     std::vector<double>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<c10::List<double>>());},
1110     "(float[] a) -> float[]");
1111   testArgTypes<std::vector<int64_t>>::test<TestLegacyAPI>(
1112     std::vector<int64_t>({1, 2}), [] (const std::vector<int64_t>& v) {expectListEquals({1, 2}, v);},
1113     std::vector<int64_t>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
1114     "(int[] a) -> int[]");
1115   // Note: std::vector<bool> is not supported; use c10::List<bool> instead.
1116   testArgTypes<std::vector<std::string>>::test<TestLegacyAPI>(
1117     std::vector<std::string>({"first", "second"}), [] (const std::vector<std::string>& v) {expectListEquals({"first", "second"}, v);},
1118     std::vector<std::string>({"first", "second"}), [] (const IValue& v) {
1119       EXPECT_EQ(2, v.toListRef().size());
1120       EXPECT_EQ("first", v.toListRef()[0].toStringRef());
1121       EXPECT_EQ("second", v.toListRef()[1].toStringRef());
1122     },
1123     "(str[] a) -> str[]");
1124   testArgTypes<std::vector<Tensor>>::test<TestLegacyAPI>(
1125     std::vector<Tensor>({dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)}), [] (const std::vector<Tensor>& v) {
1126       EXPECT_EQ(2, v.size());
1127       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.at(0)));
1128       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.at(1)));
1129     },
1130     std::vector<Tensor>({dummyTensor(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)}), [] (const IValue& v) {
1131       EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
1132       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
1133       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
1134     },
1135     "(Tensor[] a) -> Tensor[]");
1136 
1137   // Test optional of list (with nullopt)
1138   testArgTypes<std::optional<c10::List<int64_t>>>::test(
1139     std::optional<c10::List<int64_t>>(std::nullopt), [] (const std::optional<c10::List<int64_t>>& v) {EXPECT_FALSE(v.has_value());},
1140     std::optional<c10::List<int64_t>>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
1141     "(int[]? a) -> int[]?");
1142 
1143   // Test optional of list (with empty list)
1144   testArgTypes<std::optional<c10::List<int64_t>>>::test(
1145     std::optional<c10::List<int64_t>>(c10::List<int64_t>({})), [] (const std::optional<c10::List<int64_t>>& v) {EXPECT_EQ(0, v.value().size());},
1146     std::optional<c10::List<int64_t>>(c10::List<int64_t>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
1147     "(int[]? a) -> int[]?");
1148 
1149   // Test optional of list (with values)
1150   testArgTypes<std::optional<c10::List<int64_t>>>::test(
1151     std::optional<c10::List<int64_t>>(c10::List<int64_t>({1, 2})), [] (const std::optional<c10::List<int64_t>>& v) {expectListEquals({1, 2}, v.value());},
1152     std::optional<c10::List<int64_t>>(c10::List<int64_t>({3, 4})), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
1153     "(int[]? a) -> int[]?");
1154 
1155   // Test list of optional (with empty list)
1156   testArgTypes<c10::List<::std::optional<int64_t>>>::test(
1157     c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const c10::List<::std::optional<int64_t>>& v) {EXPECT_EQ(0, v.size());},
1158     c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<::std::optional<int64_t>>>().size());},
1159     "(int?[] a) -> int?[]");
1160 
1161   // Test list of optional (with values)
1162   testArgTypes<c10::List<::std::optional<int64_t>>>::test(
1163     c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, std::nullopt, 2})), [] (const c10::List<::std::optional<int64_t>>& v) {expectListEquals<std::optional<int64_t>>({3, std::nullopt, 2}, v);},
1164     c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, std::nullopt, 2})), [] (const IValue& v) {expectListEquals<std::optional<int64_t>>({3, std::nullopt, 2}, v.to<c10::List<::std::optional<int64_t>>>());},
1165     "(int?[] a) -> int?[]");
1166 
1167   // dict types
1168   c10::Dict<std::string, std::string> str_dict;
1169   str_dict.insert("key1", "value1");
1170   str_dict.insert("key2", "value2");
1171   testArgTypes<c10::Dict<std::string, std::string>>::test(
1172     str_dict, [] (c10::Dict<std::string, std::string> v) {
1173       EXPECT_EQ(2, v.size());
1174       EXPECT_EQ("value1", v.at("key1"));
1175       EXPECT_EQ("value2", v.at("key2"));
1176     },
1177     str_dict, [] (const IValue& v) {
1178       c10::Dict<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(v.toGenericDict());
1179       EXPECT_EQ(2, dict.size());
1180       EXPECT_EQ("value1", dict.at("key1"));
1181       EXPECT_EQ("value2", dict.at("key2"));
1182     },
1183     "(Dict(str, str) a) -> Dict(str, str)");
1184   c10::Dict<int64_t, Tensor> tensor_dict;
1185   tensor_dict.insert(1, dummyTensor(c10::DispatchKey::CPU));
1186   tensor_dict.insert(2, dummyTensor(c10::DispatchKey::CUDA));
1187   testArgTypes<c10::Dict<int64_t, Tensor>>::test(
1188     tensor_dict, [] (c10::Dict<int64_t, Tensor> v) {
1189       EXPECT_EQ(2, v.size());
1190       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.at(1)));
1191       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.at(2)));
1192     },
1193     tensor_dict, [] (const IValue& v) {
1194       c10::Dict<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(v.toGenericDict());
1195       EXPECT_EQ(2, dict.size());
1196       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(dict.at(1)));
1197       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(dict.at(2)));
1198     },
1199     "(Dict(int, Tensor) a) -> Dict(int, Tensor)");
1200 
1201   // deprecated dict types
1202   std::unordered_map<std::string, std::string> str_map;
1203   str_map.emplace("key1", "value1");
1204   str_map.emplace("key2", "value2");
1205   testArgTypes<std::unordered_map<std::string, std::string>>::test<TestLegacyAPI>(
1206     str_map, [] (std::unordered_map<std::string, std::string> v) {
1207       EXPECT_EQ(2, v.size());
1208       EXPECT_EQ("value1", v.at("key1"));
1209       EXPECT_EQ("value2", v.at("key2"));
1210     },
1211     str_map, [] (const IValue& v) {
1212       c10::Dict<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(v.toGenericDict());
1213       EXPECT_EQ(2, dict.size());
1214       EXPECT_EQ("value1", dict.at("key1"));
1215       EXPECT_EQ("value2", dict.at("key2"));
1216     },
1217     "(Dict(str, str) a) -> Dict(str, str)");
1218   std::unordered_map<int64_t, Tensor> tensor_map;
1219   tensor_map.emplace(1, dummyTensor(c10::DispatchKey::CPU));
1220   tensor_map.emplace(2, dummyTensor(c10::DispatchKey::CUDA));
1221   testArgTypes<std::unordered_map<int64_t, Tensor>>::test<TestLegacyAPI>(
1222     tensor_map, [] (std::unordered_map<int64_t, Tensor> v) {
1223       EXPECT_EQ(2, v.size());
1224       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.at(1)));
1225       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.at(2)));
1226     },
1227     tensor_map, [] (const IValue& v) {
1228       c10::Dict<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(v.toGenericDict());
1229       EXPECT_EQ(2, dict.size());
1230       EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(dict.at(1)));
1231       EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(dict.at(2)));
1232     },
1233     "(Dict(int, Tensor) a) -> Dict(int, Tensor)");
1234 
1235   // weird deeply nested type
1236   using DeeplyNestedType = c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>>;
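  // i.e. a list of dicts mapping str -> (list of optional dicts mapping int -> str),
  // matching the schema string passed to testArgTypes below.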
1237   auto makeDeeplyNestedObject = [] () -> DeeplyNestedType {
1238     c10::Dict<int64_t, std::string> inner3;
1239     inner3.insert(1, "1");
1240     c10::List<::std::optional<c10::Dict<int64_t, std::string>>> inner2;
1241     inner2.push_back(std::move(inner3));
1242     c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>> inner1;
1243     inner1.insert("key", std::move(inner2));
1244     c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>> result;
1245     result.push_back(inner1);
1246     return result;
1247   };
1248   testArgTypes<DeeplyNestedType>::test(
1249     makeDeeplyNestedObject(), [] (const DeeplyNestedType& v) {EXPECT_EQ("1", v.get(0).at("key").get(0).value().at(1));},
1250     makeDeeplyNestedObject(), [] (const IValue& v) {EXPECT_EQ("1", v.to<DeeplyNestedType>().get(0).at("key").get(0).value().at(1));},
1251     "(Dict(str, Dict(int, str)?[])[] a) -> Dict(str, Dict(int, str)?[])[]");
1252 }
1253 
1254 TEST(NewOperatorRegistrationTest, erroroutwithinvalidoverloadname) {
1255   auto m = MAKE_TORCH_LIBRARY(_test);
1256   expectThrows<c10::Error>([&] {
1257    m.def("dummy.default(Tensor self) -> Tensor");
1258   }, "default is not a legal overload name for aten operators");
1259   expectThrows<c10::Error>([&] {
1260    m.def("dummy.__name__(Tensor self) -> Tensor");
1261   }, "__name__ is not a legal overload name for aten operators");
1262 }
1263 
1264 TEST(NewOperatorRegistrationTest, testBasics) {
1265   auto m = MAKE_TORCH_LIBRARY(_test);
1266   m.def("dummy(Tensor self) -> Tensor");
1267   m.def("dummy1(Tensor self) -> Tensor");
1268   m.def("dummy2(Tensor self) -> Tensor");
1269   m.def("dummy3(Tensor self, Tensor other) -> Tensor", [](const Tensor& self, const Tensor& other) { return self; });
1270   m.def("dummy4", [](const Tensor& self, const Tensor& other) { return other; });
1271   m.impl("dummy", c10::DeviceType::CPU, [](const Tensor& self) { return self; });
1272   m.impl("dummy", c10::DeviceType::XLA, [](const Tensor& self) { return self; });
1273   m.impl("dummy", c10::DeviceType::Lazy, [](const Tensor& self) { return self; });
1274   // Internal API
1275   m.impl("dummy2", c10::DispatchKey::CPU, [](const Tensor& self) { return self; });
1276   m.impl("dummy2", c10::DispatchKey::XLA, [](const Tensor& self) { return self; });
1277   m.impl("dummy2", c10::DispatchKey::Lazy, [](const Tensor& self) { return self; });
1278 
1279   ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy", ""}).has_value());
1280   // Should have a schema even if there are no impls
1281   ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy1", ""}).has_value());
1282   ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy2", ""}).has_value());
1283   ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy3", ""}).has_value());
1284   ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy4", ""}).has_value());
1285 }
1286 
1287 TEST(NewOperatorRegistrationTest, importTopLevel) {
1288   auto m = MAKE_TORCH_LIBRARY(test);
1289   m.def("def1(Tensor self) -> Tensor");
1290   m.def("def2(Tensor self) -> Tensor", [](const Tensor& x) { return x; });
1291   m.def("def3", [](const Tensor& x) { return x; });
1292 
1293   auto m2 = MAKE_TORCH_LIBRARY_IMPL(test, CatchAll);
1294   m2.impl("impl1", [](const Tensor& x) { return x; });
1295 
1296   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def1", ""}).has_value());
1297   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def2", ""}).has_value());
1298   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def3", ""}).has_value());
1299   ASSERT_TRUE(Dispatcher::singleton().findOp({"test::def1", ""}).has_value());
1300   ASSERT_TRUE(Dispatcher::singleton().findOp({"test::def2", ""}).has_value());
1301   ASSERT_TRUE(Dispatcher::singleton().findOp({"test::def3", ""}).has_value());
1302   ASSERT_TRUE(Dispatcher::singleton().findOp({"test::impl1", ""}).has_value());
1303 }
1304 
1305 TEST(NewOperatorRegistrationTest, overload) {
1306   auto m = MAKE_TORCH_LIBRARY(test);
1307   m.def("fn(Tensor self) -> Tensor");
1308   m.def("fn.overload1(Tensor self, Tensor other) -> Tensor");
1309   m.def("fn.overload2(Tensor self, Tensor other, Tensor alpha) -> Tensor");
1310 
1311   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::fn", ""}).has_value());
1312   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::fn", "overload1"}).has_value());
1313   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::fn", "overload2"}).has_value());
1314 }
1315 
1316 TEST(NewOperatorRegistrationTest, importNamespace) {
1317   auto m = MAKE_TORCH_LIBRARY(test);
1318   m.def("def1(Tensor self) -> Tensor");
1319   m.def("def2(Tensor self) -> Tensor", [](const Tensor& x) { return x; });
1320   m.def("def3", [](const Tensor& x) { return x; });
1321   m.impl("impl1", [](const Tensor& x) { return x; });
1322   expectThrows<c10::Error>([&] {
1323     m.def("retest::def1(Tensor self) -> Tensor");
1324   }, "");
1325 
1326   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def1", ""}).has_value());
1327   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def2", ""}).has_value());
1328   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def3", ""}).has_value());
1329   ASSERT_TRUE(Dispatcher::singleton().findOp({"test::impl1", ""}).has_value());
1330 }
1331 
1332 TEST(NewOperatorRegistrationTest, schema) {
1333   auto m = MAKE_TORCH_LIBRARY(test);
1334   m.def("def1(Tensor self) -> Tensor");
1335   m.def(torch::schema("def2(Tensor self) -> Tensor"));
1336   m.def(torch::schema("def3(Tensor self) -> Tensor", c10::AliasAnalysisKind::PURE_FUNCTION));
1337   m.def(torch::jit::parseSchema("def4(Tensor self) -> Tensor"));
1338 
1339   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def1", ""}).has_value());
1340   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def2", ""}).has_value());
1341   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def3", ""}).has_value());
1342   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""}).has_value());
1343 
1344   EXPECT_EQ(Dispatcher::singleton().findSchema({"test::def1", ""})->schema().aliasAnalysis(), c10::AliasAnalysisKind::FROM_SCHEMA);
1345   EXPECT_EQ(Dispatcher::singleton().findSchema({"test::def2", ""})->schema().aliasAnalysis(), c10::AliasAnalysisKind::FROM_SCHEMA);
1346   EXPECT_EQ(Dispatcher::singleton().findSchema({"test::def3", ""})->schema().aliasAnalysis(), c10::AliasAnalysisKind::PURE_FUNCTION);
1347   ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""})->schema().isDefaultAliasAnalysisKind());
1348 }
1349 
1350 TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) {
1351   auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
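  // The wildcard `_` namespace plus the CPU key registers the boxed fallback below for every
  // operator's CPU slot; it only runs when dispatch lands on CPU and no more specific kernel
  // is available for the operator.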
1352   m1.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
1353 
1354   bool called = false;
1355   auto m = MAKE_TORCH_LIBRARY(test);
1356   m.def("fn(Tensor t, str input) -> ()");
1357   m.impl("fn", [&] (Tensor, std::string) { called = true; });
1358 
1359   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1360   ASSERT_TRUE(op.has_value());
1361 
1362   called = false;
1363   auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
1364   // CatchAll now maps to CompositeImplicitAutograd and has higher precedence than backend fallback.
1365   EXPECT_TRUE(called);
1366 }
1367 
1368 TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
1369   auto m = MAKE_TORCH_LIBRARY(test);
1370   m.def("fn(Tensor dummy) -> ()");
1371   m.impl("fn", c10::DispatchKey::CPU, nonautograd_kernel);
1372   m.impl("fn", c10::DispatchKey::Autograd, autograd_kernel);
1373 
1374   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1375   ASSERT_TRUE(op.has_value());
1376 
1377   called_nonautograd = called_autograd = false;
1378   callOp(*op, dummyTensor(DispatchKey::CPU));
1379   EXPECT_TRUE(called_nonautograd);
1380   EXPECT_FALSE(called_autograd);
1381 }
1382 
1383 TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradKernel) {
1384   bool math_called = false;
1385   auto m = MAKE_TORCH_LIBRARY(test);
1386   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1387 
1388   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1389   ASSERT_TRUE(op.has_value());
1390 
1391   {
1392     ASSERT_FALSE(math_called);
1393     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1394     ASSERT_TRUE(math_called);
1395   }
1396 
1397   {
1398     math_called = false;
1399     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1400     ASSERT_TRUE(math_called);
1401   }
1402 
1403   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1404     math_called = false;
1405     callOp(*op, dummyTensor(key));
1406     ASSERT_TRUE(math_called);
1407   }
1408 
1409   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1410     math_called = false;
1411     callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1412     ASSERT_TRUE(math_called);
1413   }
1414 
1415   {
1416     math_called = false;
1417     callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
1418     ASSERT_TRUE(math_called);
1419   }
1420 
1421   {
1422     math_called = false;
1423     callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
1424     ASSERT_TRUE(math_called);
1425   }
1426 }
1427 
1428 TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndAutogradKernel) {
1429   bool math_called = false;
1430   bool autograd_called = false;
1431   auto m = MAKE_TORCH_LIBRARY(test);
1432   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1433   m.impl("fn", c10::DispatchKey::Autograd, [&](const Tensor& x) { autograd_called = true; return x; });
1434 
1435   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1436   ASSERT_TRUE(op.has_value());
1437 
1438   // CompositeImplicitAutograd has higher precedence than Autograd
1439   {
1440     math_called = autograd_called = false;
1441     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1442     ASSERT_TRUE(math_called);
1443     ASSERT_FALSE(autograd_called);
1444   }
1445 
1446   {
1447     math_called = autograd_called = false;
1448     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1449     ASSERT_TRUE(math_called);
1450     ASSERT_FALSE(autograd_called);
1451   }
1452 }
1453 
1454 TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndCatchAllKernel) {
1455   bool math_called = false;
1456   bool catchall_called = false;
1457   auto m = MAKE_TORCH_LIBRARY(test);
1458   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1459   m.impl("fn", [&](const Tensor& x) { catchall_called = true; return x; });
1460 
1461   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1462   ASSERT_TRUE(op.has_value());
1463 
1464   // catchAll now maps to CompositeImplicitAutograd, which means we have two registrations to CompositeImplicitAutograd key.
1465   // The last registration is used.
1466   {
1467     catchall_called = math_called = false;
1468     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1469     ASSERT_FALSE(math_called);
1470     ASSERT_TRUE(catchall_called);
1471   }
1472 
1473   {
1474     catchall_called = math_called = false;
1475     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1476     ASSERT_FALSE(math_called);
1477     ASSERT_TRUE(catchall_called);
1478   }
1479 }
1480 
1481 TEST(NewOperatorRegistrationTest, AutogradBackendOverridesCompositeImplicitAutogradKernel) {
1482   bool math_called = false;
1483   bool autograd_called = false;
1484   auto m = MAKE_TORCH_LIBRARY(test);
1485   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1486   m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autograd_called = true; return x; });
1487 
1488   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1489   ASSERT_TRUE(op.has_value());
1490 
1491   {
1492     math_called = autograd_called = false;
1493     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1494     ASSERT_TRUE(math_called);
1495     ASSERT_FALSE(autograd_called);
1496   }
1497 
1498   {
1499     math_called = autograd_called = false;
1500     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1501     ASSERT_TRUE(autograd_called);
1502     ASSERT_FALSE(math_called);
1503   }
1504 
1505   {
1506     math_called = autograd_called = false;
1507     callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1508     ASSERT_TRUE(math_called);
1509     ASSERT_FALSE(autograd_called);
1510   }
1511 
1512   {
1513     math_called = autograd_called = false;
1514     callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1515     ASSERT_TRUE(math_called);
1516     ASSERT_FALSE(autograd_called);
1517   }
1518 }
1519 
1520 TEST(NewOperatorRegistrationTest, BackendOverridesCompositeImplicitAutogradKernel) {
1521   bool math_called = false;
1522   bool backend_called = false;
1523   auto m = MAKE_TORCH_LIBRARY(test);
1524   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1525   m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; });
1526 
1527   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1528   ASSERT_TRUE(op.has_value());
1529 
1530   {
1531     math_called = backend_called = false;
1532     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1533     ASSERT_TRUE(backend_called);
1534     ASSERT_FALSE(math_called);
1535   }
1536 
1537   {
1538     // Fallthrough AutogradCPU and reaches CPU
1539     math_called = backend_called = false;
1540     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1541     ASSERT_TRUE(backend_called);
1542     ASSERT_FALSE(math_called);
1543   }
1544 
1545   {
1546     math_called = backend_called = false;
1547     callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1548     ASSERT_TRUE(math_called);
1549     ASSERT_FALSE(backend_called);
1550   }
1551 
1552   {
1553     math_called = backend_called = false;
1554     callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1555     ASSERT_TRUE(math_called);
1556     ASSERT_FALSE(backend_called);
1557   }
1558 }
1559 
1560 TEST(NewOperatorRegistrationTest, dispatchWithCompositeExplicitAutogradKernel) {
1561   bool called = false;
1562   auto m = MAKE_TORCH_LIBRARY(test);
1563   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, [&](const Tensor& x) { called = true; return x; }));
1564 
1565   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1566   ASSERT_TRUE(op.has_value());
1567 
1568   {
1569     ASSERT_FALSE(called);
1570     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1571     ASSERT_TRUE(called);
1572   }
1573 
1574   {
1575     called = false;
1576     // AutogradCPU is fallthrough, calls CPU kernel
1577     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1578     ASSERT_TRUE(called);
1579   }
1580 
1581   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1582     called = false;
1583     callOp(*op, dummyTensor(key));
1584     ASSERT_TRUE(called);
1585   }
1586 
1587   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1588     called = false;
1589     // Autograd{XLA, Lazy} is fallthrough, calls XLA / Lazy kernel
1590     callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1591     ASSERT_TRUE(called);
1592   }
1593 
1594   {
1595     called = false;
1596     callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
1597     ASSERT_TRUE(called);
1598   }
1599 
1600   {
1601     called = false;
1602     // AutogradCPU is fallthrough, calls CPU kernel
1603     callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
1604     ASSERT_TRUE(called);
1605   }
1606 }
1607 
1608 TEST(NewOperatorRegistrationTest, dispatchWithCompositeExplicitAutogradAndCompositeImplicitAutogradKernel) {
1609   bool backend_called = false;
1610   bool math_called = false;
1611   auto m = MAKE_TORCH_LIBRARY(test);
1612   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, [&](const Tensor& x) { backend_called = true; return x; }));
1613   m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
1614 
1615   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1616   ASSERT_TRUE(op.has_value());
1617 
1618   {
1619     backend_called = math_called = false;
1620     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1621     ASSERT_TRUE(backend_called);
1622     ASSERT_FALSE(math_called);
1623   }
1624 
1625   {
1626     backend_called = math_called = false;
1627     // AutogradCPU is fallthrough, calls CPU kernel
1628     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1629     ASSERT_FALSE(math_called);
1630     ASSERT_TRUE(backend_called);
1631   }
1632 
1633   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1634     backend_called = math_called = false;
1635     callOp(*op, dummyTensor(key));
1636     ASSERT_TRUE(backend_called);
1637     ASSERT_FALSE(math_called);
1638   }
1639 
1640   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1641     backend_called = math_called = false;
1642     // Autograd{XLA, Lazy} is fallthrough, calls XLA / Lazy kernel
1643     callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1644     ASSERT_FALSE(math_called);
1645     ASSERT_TRUE(backend_called);
1646   }
1647 
1648   {
1649     backend_called = math_called = false;
1650     callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
1651     ASSERT_TRUE(backend_called);
1652     ASSERT_FALSE(math_called);
1653   }
1654 
1655   {
1656     backend_called = math_called = false;
1657     // AutogradOther is fallthrough, calls SparseCPU kernel
1658     callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
1659     ASSERT_FALSE(math_called);
1660     ASSERT_TRUE(backend_called);
1661   }
1662 }
1663 
1664 TEST(NewOperatorRegistrationTest, BackendOverridesCompositeExplicitAutogradKernel) {
1665   bool default_called = false;
1666   bool backend_called = false;
1667   auto m = MAKE_TORCH_LIBRARY(test);
1668   m.def("fn", torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, [&](const Tensor& x) { default_called = true; return x; }));
1669   m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; });
1670 
1671   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1672   ASSERT_TRUE(op.has_value());
1673 
1674   {
1675     default_called = backend_called = false;
1676     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1677     ASSERT_TRUE(backend_called);
1678     ASSERT_FALSE(default_called);
1679   }
1680 
1681   {
1682     default_called = backend_called = false;
1683     // AutogradCPU is fallthrough, calls CPU kernel
1684     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1685     ASSERT_TRUE(backend_called);
1686     ASSERT_FALSE(default_called);
1687   }
1688 
1689   {
1690     default_called = backend_called = false;
1691     callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1692     ASSERT_TRUE(default_called);
1693     ASSERT_FALSE(backend_called);
1694   }
1695 
1696   {
1697     default_called = backend_called = false;
1698     // AutogradCUDA is fallthrough, calls CUDA kernel
1699     callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1700     ASSERT_TRUE(default_called);
1701     ASSERT_FALSE(backend_called);
1702   }
1703 }
1704 
1705 
1706 TEST(NewOperatorRegistrationTest, dispatch) {
1707   bool cpu_called = false;
1708   bool cuda_called = false;
1709   bool autograd_called = false;
1710   auto m = MAKE_TORCH_LIBRARY(test);
1711   m.def("fn_cpu", torch::dispatch(c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; }));
1712   m.def("fn_cuda", torch::dispatch(c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; }));
1713   m.def("fn_autograd", torch::dispatch(c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; }));
1714 
1715   {
1716     auto op = Dispatcher::singleton().findSchema({"test::fn_cpu", ""});
1717     ASSERT_TRUE(op.has_value());
1718     ASSERT_FALSE(cpu_called);
1719     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1720     ASSERT_TRUE(cpu_called);
1721   }
1722 
1723   {
1724     auto op = Dispatcher::singleton().findSchema({"test::fn_cuda", ""});
1725     ASSERT_TRUE(op.has_value());
1726     ASSERT_FALSE(cuda_called);
1727     callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1728     ASSERT_TRUE(cuda_called);
1729   }
1730 
1731   {
1732     auto op = Dispatcher::singleton().findSchema({"test::fn_autograd", ""});
1733     ASSERT_TRUE(op.has_value());
1734     ASSERT_FALSE(autograd_called);
1735     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1736     ASSERT_TRUE(autograd_called);
1737   }
1738 
1739   for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1740     autograd_called = false;
1741     auto op = Dispatcher::singleton().findSchema({"test::fn_autograd", ""});
1742     ASSERT_TRUE(op.has_value());
1743     callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1744     ASSERT_TRUE(autograd_called);
1745   }
1746 }
1747 
1748 TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) {
1749   bool cpu_called = false;
1750   auto m = MAKE_TORCH_LIBRARY(test);
1751   m.def("fn", torch::dispatch(c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; }));
1752 
1753   {
1754     auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1755     ASSERT_TRUE(op.has_value());
1756     ASSERT_FALSE(cpu_called);
1757     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1758     ASSERT_TRUE(cpu_called);
1759   }
1760 
1761   {
1762     // AutogradCPU is fallthrough, use CPU kernel
1763     cpu_called = false;
1764     auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1765     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1766     ASSERT_TRUE(cpu_called);
1767   }
1768 
1769   bool autograd_called = false;
1770   m.impl("fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; });
1771 
1772   {
1773     auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1774     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1775     ASSERT_TRUE(autograd_called);
1776   }
1777 
1778   // Autograd backend kernel has higher precedence than Autograd alias.
1779   bool autogradcpu_called = false;
1780   m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autogradcpu_called = true; return x; });
1781 
1782   {
1783     auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1784     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1785     ASSERT_TRUE(autogradcpu_called);
1786   }
1787 }
1788 
1789 TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) {
1790   bool fpga_called = false;
1791   bool math_called = false;
1792   auto m = MAKE_TORCH_LIBRARY(test);
1793   m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; }));
1794   m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
1795 
1796   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1797   ASSERT_TRUE(op.has_value());
1798 
1799   {
1800     callOp(*op, dummyTensor(c10::DispatchKey::FPGA));
1801     ASSERT_TRUE(fpga_called);
1802   }
1803 
1804   {
1805     expectThrows<c10::Error>([&] {
1806       callOp(*op, dummyTensor(c10::DispatchKey::FPGA, /*requires_grad=*/true));
1807     }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther.");
1808   }
1809 }
1810 
1811 TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
1812   bool privateuse1_called = false;
1813   bool catchall_called = false;
1814   // Similar to in-tree AutogradCPU/AutogradCUDA etc, out-of-tree backends usually register
1815   // a fallthrough kernel for AutogradPrivateUse1.
1816   auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1);
1817   m1.fallback(CppFunction::makeFallthrough());
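  // A fallthrough kernel tells the dispatcher to skip AutogradPrivateUse1 entirely and
  // continue with the next key in the tensor's key set (PrivateUse1 here), mirroring the
  // default fallthrough behavior of AutogradCPU/AutogradCUDA when no autograd kernel exists.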
1818 
1819   auto m = MAKE_TORCH_LIBRARY(test);
1820   m.def("fn", torch::dispatch(c10::DispatchKey::PrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; }));
1821   m.impl("fn", [&](const Tensor& x, const Tensor& y) { catchall_called = true; return x; });
1822 
1823   {
1824     auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1825     ASSERT_TRUE(op.has_value());
1826     callOp(*op, dummyTensor(c10::DispatchKey::PrivateUse1), dummyTensor(c10::DispatchKey::CPU));
1827     ASSERT_TRUE(privateuse1_called);
1828   }
1829 
1830   {
1831     auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1832     ASSERT_TRUE(op.has_value());
1833     ASSERT_FALSE(catchall_called);
1834     callOp(*op, dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CPU));
1835     ASSERT_TRUE(catchall_called);
1836   }
1837 
1838   {
1839     auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1840     ASSERT_TRUE(op.has_value());
1841     catchall_called = false;
1842     callOp(*op,
1843            dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true),
1844            dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1845     ASSERT_TRUE(catchall_called);
1846   }
1847 
1848   {
1849     auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1850     ASSERT_TRUE(op.has_value());
1851     catchall_called = false;
1852     privateuse1_called = false;
1853     callOp(*op,
1854            dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
1855            dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1856     ASSERT_FALSE(catchall_called);
1857     ASSERT_TRUE(privateuse1_called);
1858   }
1859 
1860   m.impl("fn", c10::DispatchKey::AutogradPrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; });
1861 
1862   {
1863     auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1864     ASSERT_TRUE(op.has_value());
1865     privateuse1_called = false;
1866     callOp(*op,
1867            dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
1868            dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1869     ASSERT_TRUE(privateuse1_called);
1870   }
1871 }
1872 
1873 TEST(NewOperatorRegistrationTest, registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite) {
1874   bool math_called = false;
1875   bool cpu_called = false;
1876   auto m = MAKE_TORCH_LIBRARY(test);
1877   m.def("fn(Tensor dummy) -> Tensor");
1878   m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
1879   m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
1880 
1881   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1882   ASSERT_TRUE(op.has_value());
1883 
1884   {
1885     math_called = cpu_called = false;
1886     // Meta should redispatch to the AutogradOther backend,
1887     // which the composite kernel should be registered to.
1888     callOp(*op, dummyTensor(c10::DispatchKey::Meta, /*requires_grad=*/true));
1889     ASSERT_TRUE(math_called);
1890     ASSERT_FALSE(cpu_called);
1891   }
1892 }
1893 
1894 TEST(NewOperatorRegistrationTest, dispatchMultiple) {
1895   bool cpu_called = false;
1896   bool cuda_called = false;
1897   bool autograd_called = false;
1898   auto m = MAKE_TORCH_LIBRARY(test);
1899   m.def("fn(Tensor self) -> Tensor");
1900   // NB: Direct use of DispatchKey is discouraged; use the DeviceType
1901   // k-synonyms instead
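  // (c10::kCPU / c10::kCUDA are c10::DeviceType constants that map to the CPU / CUDA
  // dispatch keys; c10::kAutograd names the Autograd alias key.)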
1902   m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
1903   m.impl("fn", c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; });
1904   m.impl("fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; });
1905 
1906   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1907   ASSERT_TRUE(op.has_value());
1908 
1909   {
1910     ASSERT_FALSE(cpu_called);
1911     callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1912     ASSERT_TRUE(cpu_called);
1913 
1914     ASSERT_FALSE(cuda_called);
1915     callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1916     ASSERT_TRUE(cuda_called);
1917   }
1918 
1919   {
1920     ASSERT_FALSE(autograd_called);
1921     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1922     ASSERT_TRUE(autograd_called);
1923 
1924     autograd_called = false;
1925     callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1926     ASSERT_TRUE(autograd_called);
1927   }
1928 }
1929 
1930 TEST(NewOperatorRegistrationTest, fallback) {
1931   auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
1932   m.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
1933 
1934   auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
1935 
1936   auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
1937   ASSERT_TRUE(op.has_value());
1938   auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
1939   EXPECT_EQ("hello _test::dummy", stack[1].toStringRef());
1940 }
1941 
1942 TEST(NewOperatorRegistrationTest, BackendSelectRedispatchesToCPU) {
1943   bool cpu_called = false;
1944   bool backend_generic_called = false;
1945   auto m = MAKE_TORCH_LIBRARY(test);
1946   auto after_backend_select = c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::BackendSelect);
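  // FULL_AFTER builds the set of all keys with lower priority than BackendSelect; masking
  // the incoming key set with it below removes BackendSelect, so the redispatch falls
  // through to the backend kernel (CPU in this test).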
1947   m.def("fn(Tensor self) -> Tensor");
1948   m.impl("fn", c10::kCPU, [&](const Tensor& x) { cpu_called = true; return x; });
1949   m.impl("fn", c10::DispatchKey::BackendSelect, [&](c10::DispatchKeySet ks, const Tensor& x) {
1950      backend_generic_called = true;
1951      auto op = c10::Dispatcher::singleton().findSchema({"test::fn", ""}).value().typed<Tensor (const Tensor&)>();
1952      return c10::Dispatcher::singleton().redispatch<Tensor, const Tensor&>(op, ks & after_backend_select, x);
1953    });
1954 
1955   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1956   ASSERT_TRUE(op.has_value());
1957   callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1958   ASSERT_TRUE(cpu_called);
1959   ASSERT_TRUE(backend_generic_called);
1960 }
1961 
1962 TEST(NewOperatorRegistrationTest, TorchLibraryTwiceIsError) {
1963   {
1964     auto m = MAKE_TORCH_LIBRARY(test);
1965     expectThrows<c10::Error>([] {
1966       auto m2 = MAKE_TORCH_LIBRARY(test);
1967     }, "Only a single TORCH_LIBRARY");
1968   }
1969   // Ensure it's ok after deregistering
1970   auto m = MAKE_TORCH_LIBRARY(test);
1971 }
1972 
1973 Tensor dummy_fn(const Tensor& x) {
1974   return x;
1975 }
1976 
1977 TEST(NewOperatorRegistrationTest, CppFunction) {
1978   // Just show off the possible ways to register functions
1979   auto m = MAKE_TORCH_LIBRARY(test);
1980   m.def("fn1", &dummy_fn);
1981   // C++ will implicitly convert function to function pointer
1982   // c.f. https://en.cppreference.com/w/cpp/language/implicit_conversion#Function_to_pointer
1983   m.def("fn2", dummy_fn);
1984   m.def("fn3", [](const Tensor& x) { return x; });
1985   // These require explicit schema
1986   m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough());
1987   m.def("fn5(Tensor x) -> Tensor", CppFunction::makeFromUnboxedFunction(dummy_fn));
1988   m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
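  // A hypothetical further variant (not registered in this test): explicit dispatch can be
  // combined with these forms as well, e.g. something along the lines of
  //   m.def("fn7", torch::dispatch(c10::kCPU, dummy_fn));
  // as the dispatch()-based tests elsewhere in this file do with lambdas.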
1989 }
1990 
1991 // Some internal tests that have to be done from C++
1992 
1993 struct OpRegistrationListenerForDelayedListenerTest : public c10::OpRegistrationListener {
1994   int64_t num_registers_ = 0;
1995   int64_t num_deregisters_ = 0;
1996   void onOperatorRegistered(const OperatorHandle& op) override {
1997     num_registers_++;
1998   }
1999   void onOperatorDeregistered(const OperatorHandle& op) override {
2000     num_deregisters_++;
2001   }
2002 };
2003 
2004 TEST(NewOperatorRegistrationTest, testDelayedListener) {
2005   auto listener = std::make_unique<OpRegistrationListenerForDelayedListenerTest>();
2006   auto listener_ptr = listener.get();
2007   auto registry = Dispatcher::singleton().addRegistrationListener(std::move(listener));
2008   int64_t initial_num_registers = listener_ptr->num_registers_;
2009   int64_t initial_num_deregisters = listener_ptr->num_deregisters_;
2010   auto op = Dispatcher::singleton().findOp({"_test::dummy", ""});
2011   ASSERT_FALSE(op.has_value());
2012   auto m1 = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
2013   m1.impl("dummy", [](const Tensor& self) { return self; });
2014   EXPECT_EQ(initial_num_registers, listener_ptr->num_registers_);
2015   {
2016     auto m2 = MAKE_TORCH_LIBRARY(_test);
2017     m2.def("dummy(Tensor self) -> Tensor");
2018     EXPECT_EQ(initial_num_registers + 1, listener_ptr->num_registers_);
2019   }
2020   EXPECT_EQ(initial_num_deregisters + 1, listener_ptr->num_deregisters_);
2021 }
2022 
2023 TEST(NewOperatorRegistrationTest, testImplNoDefGetsCaught) {
2024   auto danglingImpls = Dispatcher::singleton().findDanglingImpls();
2025   std::string error_str = "Discovered operators that have been registered through the dispatcher"
2026                           " without explicitly specifying their schemas. Please do so using"
2027                           " the TORCH_LIBRARY macro. Suspect operators:\n";
2028   for (auto& op : danglingImpls) {
2029       auto& op_name = op.operator_name();
2030       error_str += "\t" + op_name.name;
2031       if (op_name.overload_name != "") {
2032           error_str += "." + op_name.overload_name;
2033       }
2034       error_str += "\n";
2035   }
2036   ASSERT_EQ(danglingImpls.size(), 0) << error_str;
2037 }
2038 
2039 bool called_kernel_cpu = false;
2040 bool called_kernel_autograd = false;
2041 bool called_kernel_tracing = false;
2042 
2043 void cpu_kernel(Tensor) {
2044   called_kernel_cpu = true;
2045 }
2046 
2047 // autograd kernel that redispatches. Explicitly takes in and updates the DispatchKeySet
2048 void autograd_kernel_redispatching_with_DispatchKeySet(c10::DispatchKeySet ks, Tensor a) {
2049   called_kernel_autograd = true;
2050   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2051   auto updatedDispatchKeySet = ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
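  // AutogradOther is the lowest-priority autograd key, so FULL_AFTER(AutogradOther) contains
  // only the keys below the autograd block; masking `ks` with it makes the call below
  // redispatch past autograd to the backend kernel (CPU in these tests).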
2052   callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, updatedDispatchKeySet, a);
2053 }
2054 
2055 // autograd kernel that redispatches. Does not take in a DispatchKeySet
2056 void autograd_kernel_redispatching_without_DispatchKeySet(c10::DispatchKeySet ks, Tensor a) {
2057   called_kernel_autograd = true;
2058   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2059   auto updatedDispatchKeySet = ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
2060   callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, updatedDispatchKeySet, a);
2061 }
2062 
2063 // tracing kernel that redispatches. Explicitly takes in and updates the DispatchKeySet
2064 void tracing_kernel_redispatching_with_DispatchKeySet(c10::DispatchKeySet ks, Tensor a) {
2065   called_kernel_tracing = true;
2066   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2067   auto updatedDispatchKeySet = ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer);
2068   callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, updatedDispatchKeySet, a);
2069 }
2070 
2071 TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_call_redispatchesToLowerPriorityKernels) {
2072   auto m = MAKE_TORCH_LIBRARY(test);
2073   m.def("fn(Tensor dummy) -> ()");
2074   m.impl("fn", c10::DispatchKey::CPU, cpu_kernel);
2075   m.impl("fn", c10::DispatchKey::AutogradCPU, autograd_kernel_redispatching_with_DispatchKeySet);
2076   m.impl("fn", c10::DispatchKey::Tracer, tracing_kernel_redispatching_with_DispatchKeySet);
2077 
2078   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2079   ASSERT_TRUE(op.has_value());
2080 
2081   called_kernel_cpu = called_kernel_autograd = called_kernel_tracing = false;
2082   auto tracing_autograd_cpu_set = c10::DispatchKeySet()
2083                                     .add(c10::DispatchKey::Tracer)
2084                                     .add(c10::DispatchKey::AutogradCPU)
2085                                     .add(c10::DispatchKey::CPU);
2086 
2087   // call Tracing -> call Autograd -> call CPU
2088   callOpUnboxed<void, Tensor>(*op, dummyTensor(tracing_autograd_cpu_set, true));
2089   EXPECT_TRUE(called_kernel_tracing);
2090   EXPECT_TRUE(called_kernel_autograd);
2091   EXPECT_TRUE(called_kernel_cpu);
2092 }
2093 
2094 TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_callBoxed_redispatchesToLowerPriorityKernels) {
2095   auto m = MAKE_TORCH_LIBRARY(test);
2096   m.def("fn(Tensor dummy) -> ()");
2097   m.impl("fn", c10::DispatchKey::CPU, cpu_kernel);
2098   m.impl("fn", c10::DispatchKey::AutogradCPU, autograd_kernel_redispatching_with_DispatchKeySet);
2099   m.impl("fn", c10::DispatchKey::Tracer, tracing_kernel_redispatching_with_DispatchKeySet);
2100 
2101   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2102   ASSERT_TRUE(op.has_value());
2103 
2104   called_kernel_cpu = called_kernel_autograd = called_kernel_tracing = false;
2105   auto tracing_autograd_cpu_set = c10::DispatchKeySet()
2106                                     .add(c10::DispatchKey::Tracer)
2107                                     .add(c10::DispatchKey::AutogradCPU)
2108                                     .add(c10::DispatchKey::CPU);
2109 
2110   // call Tracing -> call Autograd -> call CPU
2111   callOp<Tensor>(*op, dummyTensor(tracing_autograd_cpu_set, true));
2112   EXPECT_TRUE(called_kernel_tracing);
2113   EXPECT_TRUE(called_kernel_autograd);
2114   EXPECT_TRUE(called_kernel_cpu);
2115 }
2116 
2117 TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_mixedCallingConventions_redispatchesToLowerPriorityKernels) {
2118   auto m = MAKE_TORCH_LIBRARY(test);
2119   m.def("fn(Tensor dummy) -> ()");
2120   m.impl("fn", c10::DispatchKey::CPU, cpu_kernel);
2121   // the tracing kernel takes in a DispatchKeySet, but the autograd kernel does not
2122   // the dispatcher should handle correctly plumbing its DispatchKeySet to tracing and not autograd.
2123   m.impl("fn", c10::DispatchKey::AutogradCPU, autograd_kernel_redispatching_without_DispatchKeySet);
2124   m.impl("fn", c10::DispatchKey::Tracer, tracing_kernel_redispatching_with_DispatchKeySet);
2125 
2126   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2127   ASSERT_TRUE(op.has_value());
2128 
2129   called_kernel_cpu = called_kernel_autograd = called_kernel_tracing = false;
2130   auto tracing_autograd_cpu_set = c10::DispatchKeySet()
2131                                     .add(c10::DispatchKey::Tracer)
2132                                     .add(c10::DispatchKey::AutogradCPU)
2133                                     .add(c10::DispatchKey::CPU);
2134 
2135   // call Tracing -> call Autograd -> call CPU
2136   callOpUnboxed<void, Tensor>(*op, dummyTensor(tracing_autograd_cpu_set, true));
2137   EXPECT_TRUE(called_kernel_tracing);
2138   EXPECT_TRUE(called_kernel_autograd);
2139   EXPECT_TRUE(called_kernel_cpu);
2140 }
2141 
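// A purely illustrative sketch of the two kernel shapes exercised by the
// mixed-convention test above. These functions are hedged examples: they are
// not the kernels registered in these tests (those are defined earlier in this
// file), their names are made up for illustration, and nothing registers them.
[[maybe_unused]] void example_kernel_with_keyset(c10::DispatchKeySet ks, Tensor) {
  // A kernel that asks for the DispatchKeySet can mask out its own key and all
  // higher-priority keys before redispatching, e.g. keep only the keys after
  // AutogradCPU. (The redispatch itself is omitted so the sketch stays self-contained.)
  [[maybe_unused]] auto lower_keys =
      ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradCPU);
}

[[maybe_unused]] void example_kernel_without_keyset(Tensor) {
  // A plain kernel never sees the keyset; the dispatcher computes and plumbs it
  // internally, as the mixed-convention test above verifies.
}
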
TEST(OperatorRegistrationTest, getRegistrationsForDispatchKey) {
  // should return every registered op
  auto all_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(std::nullopt);
  // should return every registered op with a CPU kernel
  auto cpu_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
  ASSERT_TRUE(all_ops.size() > 0);
  ASSERT_TRUE(cpu_ops.size() > 0);

  auto cmp_lambda = [](const c10::OperatorName& a, const c10::OperatorName& b) -> bool {
      return c10::toString(a) < c10::toString(b);
  };

  std::sort(all_ops.begin(), all_ops.end(), cmp_lambda);
  std::sort(cpu_ops.begin(), cpu_ops.end(), cmp_lambda);
  ASSERT_TRUE(std::includes(all_ops.begin(), all_ops.end(), cpu_ops.begin(), cpu_ops.end(), cmp_lambda));
}

Tensor symint_op(const Tensor& self, int64_t length) {
  return self.clone();
}

TEST(OperatorRegistrationTest, TestSymNonSymCompatibility) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
  m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op));

  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::symint_op", "");

  opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
  opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));

  expectThrows<c10::Error>([&] {
    opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
  }, "Tried to access or call an operator with a wrong signature");
}

Tensor symint_op2(const Tensor& self, c10::SymInt length) {
  return self.clone();
}

TEST(OperatorRegistrationTest, TestSymSymCompatibility) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
  m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op2));

  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::symint_op", "");

  opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
  opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
  // TODO: We should reject this on principle, but today it accidentally works
  // because the call goes through the boxed calling convention.
  //
  // First, we check whether the const SymInt& signature counts as a SymInt
  // signature. It does not, because a signature is only treated as SymInt if it
  // spells out exactly SymInt. So we then look for a non-SymInt kernel. But
  // there is none, because we only registered a real SymInt kernel. When this
  // happens, we fall back to the boxed calling convention, which handles
  // const SymInt& just fine: during boxing it simply creates a SymInt to push
  // onto the argument stack (a boxed-call sketch follows this test).
  opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
}

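// A minimal, hedged follow-up sketch (not part of the original suite): it makes
// the boxed fallback described in the comment above explicit by calling through
// the boxed callOp test helper, which converts every argument to an IValue, so
// the SymInt-vs-const-SymInt& distinction never arises at this level. The test
// name is made up for illustration; registration mirrors TestSymSymCompatibility.
TEST(OperatorRegistrationTest, symIntBoxedCallSketch) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
  m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op2));

  auto op = Dispatcher::singleton().findSchema({"_test::symint_op", ""});
  ASSERT_TRUE(op.has_value());

  // Boxing builds an IValue stack, so a SymInt argument takes the same path
  // regardless of how the unboxed C++ signature spells the parameter.
  auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
  EXPECT_EQ(1u, stack.size());
  EXPECT_TRUE(stack[0].isTensor());
}
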
Tensor symint_op3(const Tensor& self, const c10::SymInt& length) {
  return self.clone();
}

TEST(OperatorRegistrationTest, TestSymSymRefCompatibility) {
  auto m = MAKE_TORCH_LIBRARY(_test);
  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);

  expectThrows<c10::Error>([&] {
    m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op3));
  }, "doesn't match the expected function schema");
}

} // namespace

#pragma GCC diagnostic pop