1 /**
2 * This file contains some general registration test cases.
3 * More detailed test cases covering the different APIs for registering kernels
4 * can be found in other files in this directory.
5 */
6
7 #include <gtest/gtest.h>
8
9 // This file intentionally tests some deprecated APIs
10 #pragma GCC diagnostic push
11 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
12
13 #include <ATen/core/boxing/impl/test_helpers.h>
14 #include <ATen/core/op_registration/op_registration.h>
15 #include <torch/library.h>
16 #include <ATen/core/Tensor.h>
17 #include <functional>
18
19 #include <ATen/core/LegacyTypeDispatch.h>
20
21 #include <algorithm>
22
23 using c10::RegisterOperators;
24 using c10::OperatorKernel;
25 using c10::OperatorHandle;
26 using c10::Dispatcher;
27 using c10::IValue;
28 using c10::DispatchKey;
29
30 using torch::Library;
31 using torch::CppFunction;
32
33 using at::Tensor;
34
35 namespace {
36
37 struct DummyKernel final : OperatorKernel {
38 void operator()(Tensor) {}
39 };
40
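// A kernel functor that sets the given flag when called, so tests can observe
// which kernel the dispatcher picked.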
41 struct MockKernel final : OperatorKernel {
42 MockKernel(bool* called): called_(called) {}
43
44 void operator()(Tensor) {
45 *called_ = true;
46 }
47 private:
48 bool* called_;
49 };
50
51 TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) {
52 bool called = false;
53 auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy(Tensor dummy) -> ()").catchAllKernel<MockKernel>(&called));
54
55 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
56 ASSERT_TRUE(op.has_value());
57 EXPECT_FALSE(called);
58 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
59 EXPECT_TRUE(called);
60 }
61
62 TEST(OperatorRegistrationTest, whenRegisteringWithSchemaAfterKernelInOptionsObject_thenCanBeCalled) {
63 bool called = false;
64 auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called).schema("_test::dummy(Tensor dummy) -> ()"));
65
66 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
67 ASSERT_TRUE(op.has_value());
68 EXPECT_FALSE(called);
69 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
70 EXPECT_TRUE(called);
71 }
72
73 TEST(OperatorRegistrationTest, whenRegisteringWithNameBeforeKernelInOptionsObject_thenCanBeCalled) {
74 bool called = false;
75 auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy").catchAllKernel<MockKernel>(&called));
76
77 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
78 ASSERT_TRUE(op.has_value());
79 EXPECT_FALSE(called);
80 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
81 EXPECT_TRUE(called);
82 }
83
84 TEST(OperatorRegistrationTest, whenRegisteringWithNameAfterKernelInOptionsObject_thenCanBeCalled) {
85 bool called = false;
86 auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called).schema("_test::dummy"));
87
88 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
89 ASSERT_TRUE(op.has_value());
90 EXPECT_FALSE(called);
91 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
92 EXPECT_TRUE(called);
93 }
94
95 TEST(OperatorRegistrationTest, whenRegisteringWithoutSchema_thenFails) {
96 expectThrows<c10::Error>([] {
97 c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel<DummyKernel>());
98 }, "In operator registration: Tried to register an operator without specifying a schema or operator name.");
99 }
100
101 TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) {
102 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::DispatchKey::CPU));
103
104 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
105 ASSERT_TRUE(op.has_value());
106 expectThrows<c10::Error>([&] {
107 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
108 }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
109 " backend.");
110 }
111
112 TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCallsCatchallKernel) {
113 bool called = false;
114 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
115
116 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
117 ASSERT_TRUE(op.has_value());
118 EXPECT_FALSE(called);
119 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
120 EXPECT_TRUE(called);
121 }
122
123 // TODO Rewrite (since this is now allowed) and reenable
124 // TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) {
125 // bool called = false;
126 // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
127 // expectThrows<c10::Error>([&] {
128 // c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
129 // }, "for an operator which already has a catch-all kernel registered");
130 // }
131
132 // TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernelInSameOpCall_thenFails) {
133 // bool called = false;
134 // expectThrows<c10::Error>([&] {
135 // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
136 // .catchAllKernel<MockKernel>(&called)
137 // .kernel<MockKernel>(c10::DispatchKey::CPU, &called));
138 // }, "for an operator which already has a catch-all kernel registered");
139 // }
140
141 TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegisteringCatchallKernelAndCallingOp_thenCallsCatchallKernel) {
142 bool called = false;
143 {
144 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
145 }
146
147 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
148
149 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
150 ASSERT_TRUE(op.has_value());
151 EXPECT_FALSE(called);
152 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
153 EXPECT_TRUE(called);
154 }
155
156 // TODO Rewrite (since this is now allowed) and reenable
157 // TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) {
158 // bool called = false;
159 // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
160 // expectThrows<c10::Error>([&] {
161 // c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
162 // }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPU. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
163 // }
164 //
165 // TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernelInSameOpCall_thenFails) {
166 // bool called = false;
167 // expectThrows<c10::Error>([&] {
168 // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
169 // .kernel<MockKernel>(c10::DispatchKey::CPU, &called)
170 // .catchAllKernel<MockKernel>(&called));
171 // }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPU. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
172 // }
173
174 TEST(OperatorRegistrationTest, givenOpWithCatchallKernelOutOfScope_whenRegisteringDispatchedKernelAndCallingOp_thenCallsCatchallKernel) {
175 bool called = false;
176 {
177 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
178 }
179
180 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::DispatchKey::CPU, &called));
181
182 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
183 ASSERT_TRUE(op.has_value());
184 EXPECT_FALSE(called);
185 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
186 EXPECT_TRUE(called);
187 }
188
189 TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithSchema_thenOnlyRegistersSchema) {
190 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
191
192 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
193 ASSERT_TRUE(op.has_value()); // assert schema is registered
194 expectThrows<c10::Error>([&] {
195 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
196 }, "Could not run '_test::dummy' with arguments from the 'CPU'"
197 " backend.");
198 }
199
200 TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithoutSchema_thenFails) {
201 expectThrows<c10::Error>([&] {
202 c10::RegisterOperators().op("_test::dummy");
203 }, "Cannot infer operator schema in registration of operator _test::dummy because there is no kernel specified.");
204 }
205
206 TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenSchemaIsGone) {
207 {
208 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
209 }
210
211 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
212 EXPECT_FALSE(op.has_value());
213 }
214
215 TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegistering_thenRegisters) {
216 // as long as we don't register non-catchall kernels, ops without tensor arguments are fine
217 auto registrar = c10::RegisterOperators().op("_test::dummy() -> ()");
218
219 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
220 ASSERT_TRUE(op.has_value()); // assert schema is registered
221 }
222
223 TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) {
224 expectThrows<c10::Error>([&] {
225 auto registrar = c10::RegisterOperators()
226 .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
227 .kernel<DummyKernel>(c10::DispatchKey::CPU)
228 .kernel<DummyKernel>(c10::DispatchKey::CPU));
229 }, "In operator registration: Tried to register multiple kernels with same dispatch key CPU for operator schema _test::dummy");
230 }
231
232 TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) {
233 expectThrows<c10::Error>([&] {
234 auto registrar = c10::RegisterOperators()
235 .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
236 .catchAllKernel<DummyKernel>()
237 .catchAllKernel<DummyKernel>());
238 }, "Tried to register multiple catch-all kernels for operator schema _test::dummy");
239 }
240
241 TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnboxedWithCPUDispatchKey) {
242 bool called_kernel_cpu = false;
243 auto registrar= c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
244 .kernel<MockKernel>(c10::DispatchKey::CPU, &called_kernel_cpu));
245
246 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
247 ASSERT_TRUE(op.has_value()); // assert schema is registered
248
249 // Ensure that the dispatcher doesn't take the dispatch key from the tensor but from the direct argument instead.
250 called_kernel_cpu = false;
251 callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
252 EXPECT_TRUE(called_kernel_cpu);
253
254 // Ensure that the dispatch key from the tensor is not used here.
255 called_kernel_cpu = false;
256 expectThrows<c10::Error>([&] {
257 callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
258 }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
259 " backend.");
260 }
261
262 std::string expectedMessageForBackend(DispatchKey key) {
263 std::string key_str(c10::toString(key));
264 return "Could not run '_test::dummy' with arguments from the '" + key_str + "' backend";
265 }
266
267 TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCalling_thenCallsCorrectKernel) {
268 bool called_kernel1 = false;
269 bool called_kernel2 = false;
270 auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
271 .kernel<MockKernel>(c10::DispatchKey::CPU, &called_kernel1)
272 .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel2));
273
274 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
275 ASSERT_TRUE(op.has_value()); // assert schema is registered
276
277 called_kernel1 = called_kernel2 = false;
278 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
279 EXPECT_TRUE(called_kernel1);
280 EXPECT_FALSE(called_kernel2);
281
282 called_kernel1 = called_kernel2 = false;
283 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
284 EXPECT_FALSE(called_kernel1);
285 EXPECT_TRUE(called_kernel2);
286
287 // Test for out-of-tree lazy backends; the ::Lazy key is now registered to the TS backend in-tree.
288 for (c10::DispatchKey key : {c10::DispatchKey::XLA}) {
289 std::string expectMessage = expectedMessageForBackend(key);
290 expectThrows<c10::Error>([&] {
291 callOp(*op, dummyTensor(key));
292 }, expectMessage.c_str());
293
294 // also assert that the error message contains the available tensor type ids, but don't assert their order
295 expectThrows<c10::Error>([&] {
296 callOp(*op, dummyTensor(key));
297 }, "CPU");
298 expectThrows<c10::Error>([&] {
299 callOp(*op, dummyTensor(key));
300 }, "CUDA");
301 }
302 }
303
304 bool called_stackbased_kernel = false;
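// A boxed (stack-based) kernel; an operator schema cannot be inferred from this
// kind of kernel, which the tests below rely on.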
305 void stackBasedKernel(const OperatorHandle&, c10::Stack* stack) {
306 called_stackbased_kernel = true;
307 }
308
309 TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndNoneCanInferSchema_thenFails) {
310 expectThrows<c10::Error>([&] {
311 auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
312 .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
313 .kernel<&stackBasedKernel>(c10::DispatchKey::CUDA)
314 .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
315 .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));
316 }, "Cannot infer operator schema for this kind of kernel in registration of operator _test::dummy");
317 }
318
319 TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndNoneCanInferSchema_thenSucceeds) {
320 bool called_kernel = false;
321 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
322 .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
323 .kernel<&stackBasedKernel>(c10::DispatchKey::CUDA)
324 .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
325 .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));
326
327 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
328 ASSERT_TRUE(op.has_value()); // assert schema is registered
329
330 called_kernel = called_stackbased_kernel = false;
331 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
332 EXPECT_TRUE(called_stackbased_kernel);
333 EXPECT_FALSE(called_kernel);
334
335 called_kernel = called_stackbased_kernel = false;
336 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
337 EXPECT_TRUE(called_stackbased_kernel);
338 EXPECT_FALSE(called_kernel);
339
340 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
341 called_kernel = called_stackbased_kernel = false;
342 callOp(*op, dummyTensor(key));
343 EXPECT_TRUE(called_stackbased_kernel);
344 EXPECT_FALSE(called_kernel);
345 }
346 }
347
348 TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndOnlyOneCanInferSchema_thenSucceeds) {
349 bool called_kernel = false;
350 auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
351 .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
352 .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel)
353 .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
354 .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));
355
356 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
357 ASSERT_TRUE(op.has_value()); // assert schema is registered
358
359 called_kernel = called_stackbased_kernel = false;
360 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
361 EXPECT_TRUE(called_stackbased_kernel);
362 EXPECT_FALSE(called_kernel);
363
364 called_kernel = called_stackbased_kernel = false;
365 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
366 EXPECT_FALSE(called_stackbased_kernel);
367 EXPECT_TRUE(called_kernel);
368
369 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
370 called_kernel = called_stackbased_kernel = false;
371 callOp(*op, dummyTensor(key));
372 EXPECT_TRUE(called_stackbased_kernel);
373 EXPECT_FALSE(called_kernel);
374 }
375 }
376
377 TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndOnlyOneCanInferSchema_thenSucceeds) {
378 bool called_kernel = false;
379 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
380 .kernel<&stackBasedKernel>(c10::DispatchKey::CPU)
381 .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel)
382 .kernel<&stackBasedKernel>(c10::DispatchKey::XLA)
383 .kernel<&stackBasedKernel>(c10::DispatchKey::Lazy));
384
385 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
386 ASSERT_TRUE(op.has_value()); // assert schema is registered
387
388 called_kernel = called_stackbased_kernel = false;
389 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
390 EXPECT_TRUE(called_stackbased_kernel);
391 EXPECT_FALSE(called_kernel);
392
393 called_kernel = called_stackbased_kernel = false;
394 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
395 EXPECT_FALSE(called_stackbased_kernel);
396 EXPECT_TRUE(called_kernel);
397
398 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
399 called_kernel = called_stackbased_kernel = false;
400 callOp(*op, dummyTensor(key));
401 EXPECT_TRUE(called_stackbased_kernel);
402 EXPECT_FALSE(called_kernel);
403 }
404 }
405
406 struct DummyKernelWithIntParam final : OperatorKernel {
407 void operator()(Tensor, int64_t) {}
408 };
409
410 TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_thenFails) {
411 bool called_kernel = false;
412 expectThrows<c10::Error>([&] {
413 auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
414 .kernel<DummyKernelWithIntParam>(c10::DispatchKey::CPU)
415 .kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel));
416 }, "Mismatch in kernel C++ signatures");
417 }
418
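// A boxed backend fallback used by the tests below: it appends the operator name
// to the string argument sitting at stack slot 1.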
419 void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) {
420 (*stack)[1] = (*stack)[1].toStringRef() + op.schema().name();
421 }
422
423 TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernel_thenCanBeCalled) {
424 auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");
425
426 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
427 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
428 ASSERT_TRUE(op.has_value());
429 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
430 EXPECT_EQ("hello _test::dummy", stack[1].toStringRef());
431 }
432
433 TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelForWrongBackend_thenCannotBeCalled) {
434 auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CUDA, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");
435
436 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
437 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
438 ASSERT_TRUE(op.has_value());
439 expectThrows<c10::Error>([&] {
440 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
441 }, "Could not run '_test::dummy' with arguments from the 'CPU' backend.");
442 }
443
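// Flag shared by the lambda kernels registered in the backend fallback tests below.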
444 bool called = false;
445
446 TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForDifferentBackend_thenRegularKernelCanBeCalled) {
447 auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");
448
449 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
450 .kernel(c10::DispatchKey::CUDA, [] (Tensor, std::string) {
451 called = true;
452 }));
453 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
454 ASSERT_TRUE(op.has_value());
455
456 called = false;
457 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CUDA), "hello ");
458 EXPECT_TRUE(called);
459 }
460
461 TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForDifferentBackend_thenFallbackKernelCanBeCalled) {
462 auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");
463
464 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
465 .kernel(c10::DispatchKey::CUDA, [] (Tensor, std::string) {
466 called = true;
467 }));
468 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
469 ASSERT_TRUE(op.has_value());
470
471 called = false;
472 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
473 EXPECT_FALSE(called);
474 EXPECT_EQ("hello _test::dummy", stack[1].toStringRef());
475 }
476
477 TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForSameBackend_thenCallsRegularKernel) {
478 auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");
479
480 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
481 .kernel(c10::DispatchKey::CPU, [] (Tensor, std::string) {
482 called = true;
483 }));
484 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
485 ASSERT_TRUE(op.has_value());
486
487 called = false;
488 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
489 EXPECT_TRUE(called);
490 }
491
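// Flags and free-function kernels used by the Autograd dispatch key tests below.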
492 bool called_autograd = false;
493 bool called_nonautograd = false;
494
495 void nonautograd_kernel(Tensor a) {
496 called_nonautograd = true;
497 }
498
499 void autograd_kernel(Tensor a) {
500 called_autograd = true;
501 }
502
503 TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
504 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
505 .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));
506
507 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
508 ASSERT_TRUE(op.has_value());
509
510 called_autograd = false;
511 expectThrows<c10::Error>([&] {
512 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
513 }, "Could not run '_test::dummy' with arguments from the 'CPU'"
514 " backend.");
515
516 op->typed<void(Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
517 EXPECT_TRUE(called_autograd);
518 }
519
520 TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
521 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
522 .kernel<decltype(nonautograd_kernel), nonautograd_kernel>(DispatchKey::CPU)
523 .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));
524
525 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
526 ASSERT_TRUE(op.has_value());
527
528 called_nonautograd = called_autograd = false;
529 op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
530 EXPECT_FALSE(called_nonautograd);
531 EXPECT_TRUE(called_autograd);
532 }
533
534 TEST(
535 OperatorRegistrationTest,
536 whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) {
537 auto registrar = c10::RegisterOperators().op(
538 "_test::dummy(Tensor dummy) -> ()",
539 c10::RegisterOperators::options()
540 .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
541 .kernel<decltype(autograd_kernel), &autograd_kernel>(
542 DispatchKey::Autograd));
543
544 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
545 ASSERT_TRUE(op.has_value());
546
547 // catchAll now maps to CompositeImplicitAutograd which has
548 // higher precedence than Autograd
549 called_nonautograd = called_autograd = false;
550 op->typed<void(Tensor)>().call(
551 dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
552 EXPECT_TRUE(called_nonautograd);
553 EXPECT_FALSE(called_autograd);
554
555 called_nonautograd = called_autograd = false;
556 op->typed<void(Tensor)>().call(dummyTensor(DispatchKey::CPU));
557 EXPECT_TRUE(called_nonautograd);
558 EXPECT_FALSE(called_autograd);
559 }
560
561 TEST(OperatorRegistrationTest, AutogradBackendOverridesAutogradKernel) {
562 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
563 .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(DispatchKey::AutogradCPU)
564 .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));
565
566 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
567 ASSERT_TRUE(op.has_value());
568
569 expectThrows<c10::Error>([&] {
570 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
571 }, "Could not run '_test::dummy' with arguments from the 'CPU'"
572 " backend.");
573
574 expectThrows<c10::Error>([&] {
575 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
576 }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
577 " backend.");
578
579 called_nonautograd = called_autograd = false;
580 op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
581 EXPECT_TRUE(called_nonautograd);
582 EXPECT_FALSE(called_autograd);
583
584 called_nonautograd = called_autograd = false;
585 op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CUDA, /*requires_grad=*/true));
586 EXPECT_TRUE(called_autograd);
587 EXPECT_FALSE(called_nonautograd);
588 }
589
590 void LazyBackendsAutogradOverridesAutogradKernel(DispatchKey key) {
591 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
592 .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(toBackendComponent(key)))
593 .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));
594
595 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
596 ASSERT_TRUE(op.has_value());
597
598 std::string expectedMessage = expectedMessageForBackend(key);
599 expectThrows<c10::Error>([&] {
600 callOp(*op, dummyTensor(key));
601 }, expectedMessage.c_str());
602
603 called_nonautograd = called_autograd = false;
604 op->typed<void (Tensor)>().call(dummyTensor(key, /*requires_grad=*/true));
605 EXPECT_TRUE(called_nonautograd);
606 EXPECT_FALSE(called_autograd);
607
608 called_nonautograd = called_autograd = false;
609 op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
610 EXPECT_TRUE(called_autograd);
611 EXPECT_FALSE(called_nonautograd);
612 }
613
614 // We no longer test the ::Lazy key here,
615 // since it is now registered to the TS backend in-tree and thus behaves differently:
616 // it does not throw the expected 'could not run..' messages.
617 TEST(OperatorRegistrationTest, AutogradXLAOverridesAutogradKernel) {
618 LazyBackendsAutogradOverridesAutogradKernel(DispatchKey::XLA);
619 }
620
621 void whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey key) {
622 {
623 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
624 .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>());
625
626 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
627 ASSERT_TRUE(op.has_value());
628
629 called_nonautograd = called_autograd = false;
630 op->typed<void (Tensor)>().call(dummyTensor(key, /*requires_grad=*/true));
631 EXPECT_TRUE(called_nonautograd);
632 EXPECT_FALSE(called_autograd);
633
634 called_nonautograd = called_autograd = false;
635 op->typed<void (Tensor)>().call(dummyTensor(key));
636 EXPECT_FALSE(called_autograd);
637 EXPECT_TRUE(called_nonautograd);
638 }
639 {
640 auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
641 .kernel<decltype(autograd_kernel), &autograd_kernel>(key)
642 .catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>());
643
644 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
645 ASSERT_TRUE(op.has_value());
646
647 // When there's a direct registration to the XLA / Lazy backend, Autograd{XLA, Lazy} doesn't pick up the catchAll
648 // kernel in precompute but just keeps the fallthrough kernel from the backend fallback.
649 // Thus dispatch falls through Autograd{XLA, Lazy} and reaches the kernel at the XLA / Lazy key.
650 called_nonautograd = called_autograd = false;
651 op->typed<void (Tensor)>().call(dummyTensor(key, /*requires_grad=*/true));
652 EXPECT_FALSE(called_nonautograd);
653 EXPECT_TRUE(called_autograd);
654
655 called_nonautograd = called_autograd = false;
656 op->typed<void (Tensor)>().call(dummyTensor(key));
657 EXPECT_TRUE(called_autograd);
658 EXPECT_FALSE(called_nonautograd);
659 }
660 }
661
662 TEST(OperatorRegistrationTest, whenRegisterWithXLAKernelAndCatchAll_AutogradXLAIsNotFilled) {
663 whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey::XLA);
664 }
665
666 TEST(OperatorRegistrationTest, whenRegisterWithLazyKernelAndCatchAll_AutogradLazyIsNotFilled) {
667 whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey::Lazy);
668 }
669
670 TEST(OperatorRegistrationTest, whenregisteringwithinvalidoverloadname) {
671 expectThrows<c10::Error>([] {
672 auto registrar = c10::RegisterOperators().op("_test::dummy.default", c10::RegisterOperators::options()
673 .kernel(DispatchKey::CPU, [] (const int64_t&) {}));
674 }, "default is not a legal overload name for aten operators");
675 expectThrows<c10::Error>([] {
676 auto registrar = c10::RegisterOperators().op("_test::dummy.__name__", c10::RegisterOperators::options()
677 .kernel(DispatchKey::CPU, [] (const int64_t&) {}));
678 }, "__name__ is not a legal overload name for aten operators");
679 }
680
681 TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingCppSignatures_thenFails) {
682 expectThrows<c10::Error>([] {
683 auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
684 .kernel(DispatchKey::CPU, [] (const int64_t&) {})
685 .kernel(DispatchKey::CUDA, [] (int64_t) {}));
686 }, "Mismatch in kernel C++ signatures");
687 }
688
689 TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) {
690 expectThrows<c10::Error>([] {
691 auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
692 .kernel(DispatchKey::CPU, [] (const int64_t&) {})
693 .catchAllKernel([] (int64_t) {}));
694 }, "Mismatch in kernel C++ signatures");
695 }
696
697 TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) {
698 expectThrows<c10::Error>([] {
699 auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
700 .catchAllKernel([] (const int64_t&) {})
701 .kernel(DispatchKey::CPU, [] (int64_t) {}));
702 }, "Mismatch in kernel C++ signatures");
703 }
704
705 TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) {
706 auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
707 .kernel(DispatchKey::CPU, [] (int64_t) {}));
708 expectThrows<c10::Error>([] {
709 c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
710 .typed<void(const int64_t&)>();
711 }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int _0) -> ()");
712 }
713
714 TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) {
715 auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
716 .catchAllKernel([] (int64_t) {}));
717 expectThrows<c10::Error>([] {
718 c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
719 .typed<void(const int64_t&)>();
720 }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int _0) -> ()");
721 }
722
723 TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingCppSignatures_thenFails) {
724 auto m = MAKE_TORCH_LIBRARY(_test);
725 m.def("dummy(int a) -> ()");
726 m.impl("dummy", DispatchKey::CPU, [] (int64_t) {});
727 expectThrows<c10::Error>([&] {
728 m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {});
729 }, "Mismatch in kernel C++ signatures");
730 }
731
732 TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) {
733 auto m = MAKE_TORCH_LIBRARY(_test);
734 m.def("dummy(int a) -> ()");
735 m.impl("dummy", DispatchKey::CPU, [] (int64_t) {});
736 expectThrows<c10::Error>([] {
737 c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
738 .typed<void(const int64_t&)>();
739 }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int a) -> ()");
740 }
741
742 TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) {
743 auto m = MAKE_TORCH_LIBRARY(_test);
744 m.def("dummy(int a) -> ()", [] (int64_t) {});
745 expectThrows<c10::Error>([] {
746 c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "")
747 .typed<void(const int64_t&)>();
748 }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int a) -> ()");
749 }
750
751 /**
752 * This is used to check that a given type works correctly when passed as input
753 * to or as output from a kernel.
754 *
755 * Call ArgTypeTestKernel<Input, Output>::test(input, inputExpectation, output, outputExpectation, schema)
756 * to test that a kernel with `Input` as its input type and `Output` as its output type,
757 * when called with `input` fulfills `inputExpectation` inside the kernel, then
758 * returns `output` and the returned value fulfills `outputExpectation`.
759 *
760 * `inputExpectation` and `outputExpectation` should be lambdas that run
761 * googletest expect macros (or use other ways to assert the expectation is met).
762 *
763 * Optionally, you can specify the argument list part of a function schema
764 * (e.g. "(Tensor a) -> Tensor") as an additional argument to use when
765 * registering the kernel. In this case, the operator registration logic will
766 * check that the kernel function signature matches the one you specified.
767 */
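// A minimal usage sketch (illustrative only; the real invocations live in the
// testAvailableArgTypes test further below), going through the testArgTypes
// wrapper defined after this helper:
//
//   testArgTypes<double>::test(
//     1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
//     2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
//     "(float a) -> float");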
768 struct TestModernAPI final {};
769 struct TestLegacyAPI final {};
770 struct TestModernAndLegacyAPI final {};
771
772 template<class InputType, class OutputType = InputType>
773 struct ArgTypeTestKernel final : OperatorKernel {
774 explicit ArgTypeTestKernel(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output)
775 : input_(std::move(input)), inputExpectation_(std::move(inputExpectation)), output_(std::move(output)) {}
776
777 OutputType operator()(InputType input) const {
778 inputExpectation_(std::move(input));
779 return output_;
780 }
781
782 static void test(TestModernAndLegacyAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
783 test(TestModernAPI(), input, inputExpectation, output, outputExpectation, schema);
784 test(TestLegacyAPI(), input, inputExpectation, output, outputExpectation, schema);
785 }
786
787 static void test(TestModernAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
788 return test_([&] {
789 return c10::RegisterOperators().op("_test::my_op" + schema, c10::RegisterOperators::options().catchAllKernel<ArgTypeTestKernel>(input, inputExpectation, output));
790 }, input, inputExpectation, output, outputExpectation, schema);
791 }
792
793 static void test(TestLegacyAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
794 return test_([&] {
795 return c10::RegisterOperators().op("_test::my_op" + schema, [=] (InputType input) -> OutputType {
796 inputExpectation(std::move(input));
797 return output;
798 });
799 }, input, inputExpectation, output, outputExpectation, schema);
800 }
801
802 private:
803 static void test_(std::function<c10::RegisterOperators()> registration, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
804 auto registry = registration();
805 auto op = Dispatcher::singleton().findSchema({"_test::my_op", ""});
806 ASSERT_TRUE(op.has_value()); // assert schema is registered
807 auto actualOutput = callOp(*op, input);
808 outputExpectation(actualOutput);
809 }
810
811 InputType input_;
812 std::function<void(const InputType&)> inputExpectation_;
813 OutputType output_;
814 std::string schema_;
815 };
816
817 template<class InputType, class OutputType = InputType>
818 struct testArgTypes final {
819 template<class APIType = TestModernAndLegacyAPI>
820 static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema) {
821 // Test with explicitly specified schema
822 ArgTypeTestKernel<InputType, OutputType>::test(
823 APIType(), input, inputExpectation, output, [&] (const c10::Stack& output) {
824 EXPECT_EQ(1, output.size());
825 outputExpectation(output[0]);
826 }, schema
827 );
828
829 // Test with inferred schema
830 ArgTypeTestKernel<InputType, OutputType>::test(
831 APIType(), input, inputExpectation, output, [&] (const c10::Stack& output) {
832 EXPECT_EQ(1, output.size());
833 outputExpectation(output[0]);
834 }, ""
835 );
836
837 // Test taking argument and returning nothing
838 ArgTypeTestKernel<InputType, std::tuple<>>::test(
839 APIType(), input, inputExpectation, {}, [] (const c10::Stack&) {}, ""
840 );
841
842 // Test taking argument and returning multiple outputs
843 ArgTypeTestKernel<InputType, std::tuple<int64_t, OutputType>>::test(
844 APIType(), input, inputExpectation, std::tuple<int64_t, OutputType>{3, output}, [&] (const c10::Stack& output) {
845 EXPECT_EQ(2, output.size());
846 EXPECT_EQ(3, output[0].toInt());
847 outputExpectation(output[1]);
848 }, ""
849 );
850 }
851 };
852
853 TEST(OperatorRegistrationTest, testAvailableArgTypes) {
854 // TODO Test Scalar
855
856 // primitive types
857 testArgTypes<double>::test(
858 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
859 2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
860 "(float a) -> float");
861 testArgTypes<int64_t>::test(
862 1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
863 2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
864 "(int a) -> int");
865 testArgTypes<bool>::test(
866 true, [] (const bool& v) {EXPECT_EQ(true, v);},
867 false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
868 "(bool a) -> bool");
869 testArgTypes<bool>::test(
870 false, [] (const bool& v) {EXPECT_EQ(false, v);},
871 true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
872 "(bool a) -> bool");
873 testArgTypes<std::string>::test(
874 "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
875 "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toStringRef());},
876 "(str a) -> str");
877 testArgTypes<Tensor>::test(
878 dummyTensor(c10::DispatchKey::CPU), [] (const Tensor& v) {EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v));},
879 dummyTensor(c10::DispatchKey::CUDA), [] (const IValue& v) {EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.toTensor()));},
880 "(Tensor a) -> Tensor");
881
882
883 // optional types (with has_value() == true)
884 testArgTypes<std::optional<double>>::test(
885 std::optional<double>(1.5), [] (const std::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
886 std::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
887 "(float? a) -> float?");
888 testArgTypes<std::optional<int64_t>>::test(
889 std::optional<int64_t>(1), [] (const std::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
890 std::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
891 "(int? a) -> int?");
892 testArgTypes<std::optional<bool>>::test(
893 std::optional<bool>(true), [] (const std::optional<bool>& v) {EXPECT_EQ(true, v.value());},
894 std::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
895 "(bool? a) -> bool?");
896 testArgTypes<std::optional<bool>>::test(
897 std::optional<bool>(false), [] (const std::optional<bool>& v) {EXPECT_EQ(false, v.value());},
898 std::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
899 "(bool? a) -> bool?");
900 testArgTypes<std::optional<std::string>>::test(
901 std::optional<std::string>("string1"), [] (const std::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
902 std::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toStringRef());},
903 "(str? a) -> str?");
904 testArgTypes<std::optional<Tensor>>::test(
905 std::optional<Tensor>(dummyTensor(c10::DispatchKey::CPU)), [] (const std::optional<Tensor>& v) {EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.value()));},
906 std::optional<Tensor>(dummyTensor(c10::DispatchKey::CUDA)), [] (const IValue& v) {EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.toTensor()));},
907 "(Tensor? a) -> Tensor?");
908
909
910 // optional types (with has_value() == false)
911 testArgTypes<std::optional<double>>::test(
912 std::optional<double>(std::nullopt), [] (const std::optional<double>& v) {EXPECT_FALSE(v.has_value());},
913 std::optional<double>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
914 "(float? a) -> float?");
915 testArgTypes<std::optional<int64_t>>::test(
916 std::optional<int64_t>(std::nullopt), [] (const std::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
917 std::optional<int64_t>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
918 "(int? a) -> int?");
919 testArgTypes<std::optional<bool>>::test(
920 std::optional<bool>(std::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
921 std::optional<bool>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
922 "(bool? a) -> bool?");
923 testArgTypes<std::optional<bool>>::test(
924 std::optional<bool>(std::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
925 std::optional<bool>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
926 "(bool? a) -> bool?");
927 testArgTypes<std::optional<std::string>>::test(
928 std::optional<std::string>(std::nullopt), [] (const std::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
929 std::optional<std::string>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
930 "(str? a) -> str?");
931 testArgTypes<std::optional<Tensor>>::test(
932 std::optional<Tensor>(std::nullopt), [] (const std::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
933 std::optional<Tensor>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
934 "(Tensor? a) -> Tensor?");
935
936
937 // list types (with empty list)
938 testArgTypes<c10::List<double>>::test(
939 c10::List<double>(), [] (const c10::List<double>& v) {EXPECT_EQ(0, v.size());},
940 c10::List<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<double>>().size());},
941 "(float[] a) -> float[]");
942 testArgTypes<c10::List<int64_t>>::test(
943 c10::List<int64_t>(), [] (const c10::List<int64_t>& v) {EXPECT_EQ(0, v.size());},
944 c10::List<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
945 "(int[] a) -> int[]");
946 testArgTypes<c10::List<bool>>::test(
947 c10::List<bool>(), [] (const c10::List<bool>& v) {EXPECT_EQ(0, v.size());},
948 c10::List<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<bool>>().size());},
949 "(bool[] a) -> bool[]");
950 testArgTypes<c10::List<std::string>>::test(
951 c10::List<std::string>(), [] (const c10::List<std::string>& v) {EXPECT_EQ(0, v.size());},
952 c10::List<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
953 "(str[] a) -> str[]");
954
955
956 // list types (with non-empty list)
957 testArgTypes<c10::List<double>>::test(
958 c10::List<double>({1.5, 2.5}), [] (const c10::List<double>& v) {expectListEquals({1.5, 2.5}, v);},
959 c10::List<double>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<c10::List<double>>());},
960 "(float[] a) -> float[]");
961 testArgTypes<c10::List<int64_t>>::test(
962 c10::List<int64_t>({1, 2}), [] (const c10::List<int64_t>& v) {expectListEquals({1, 2}, v);},
963 c10::List<int64_t>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
964 "(int[] a) -> int[]");
965 testArgTypes<c10::List<bool>>::test(
966 c10::List<bool>({true, false}), [] (const c10::List<bool>& v) {expectListEquals({true, false}, v);},
967 c10::List<bool>({true, false}), [] (const IValue& v) {expectListEquals({true, false}, v.to<c10::List<bool>>());},
968 "(bool[] a) -> bool[]");
969 testArgTypes<c10::List<std::string>>::test(
970 c10::List<std::string>({"first", "second"}), [] (const c10::List<std::string>& v) {expectListEquals({"first", "second"}, v);},
971 c10::List<std::string>({"first", "second"}), [] (const IValue& v) {
972 EXPECT_EQ(2, v.toListRef().size());
973 EXPECT_EQ("first", v.toListRef()[0].toStringRef());
974 EXPECT_EQ("second", v.toListRef()[1].toStringRef());
975 },
976 "(str[] a) -> str[]");
977 testArgTypes<c10::List<Tensor>>::test(
978 c10::List<Tensor>({dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)}), [] (const c10::List<Tensor>& v) {
979 EXPECT_EQ(2, v.size());
980 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.get(0)));
981 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.get(1)));
982 },
983 c10::List<Tensor>({dummyTensor(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)}), [] (const IValue& v) {
984 EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
985 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
986 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
987 },
988 "(Tensor[] a) -> Tensor[]");
989
990 // ArrayRef list types (with empty list)
991 testArgTypes<c10::ArrayRef<double>, c10::List<double>>::test(
992 c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
993 c10::List<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<double>>().size());},
994 "(float[] a) -> float[]");
995 testArgTypes<c10::ArrayRef<int64_t>, c10::List<int64_t>>::test(
996 c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
997 c10::List<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
998 "(int[] a) -> int[]");
999 testArgTypes<c10::ArrayRef<std::string>, c10::List<std::string>>::test(
1000 c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
1001 c10::List<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
1002 "(str[] a) -> str[]");
1003
1004
1005 // list types (with non-empty list)
1006 testArgTypes<c10::ArrayRef<double>, c10::List<double>>::test(
1007 c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {expectListEquals({1.5, 2.5}, v);},
1008 c10::List<double>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<c10::List<double>>());},
1009 "(float[] a) -> float[]");
1010 testArgTypes<c10::ArrayRef<int64_t>, c10::List<int64_t>>::test(
1011 c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {expectListEquals({1, 2}, v);},
1012 c10::List<int64_t>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
1013 "(int[] a) -> int[]");
1014 testArgTypes<c10::ArrayRef<std::string>, c10::List<std::string>>::test(
1015 c10::ArrayRef<std::string>({"first", "second"}), [] (c10::ArrayRef<std::string> v) {expectListEquals({"first", "second"}, v);},
1016 c10::List<std::string>({"first", "second"}), [] (const IValue& v) {
1017 EXPECT_EQ(2, v.toListRef().size());
1018 EXPECT_EQ("first", v.toListRef()[0].toStringRef());
1019 EXPECT_EQ("second", v.toListRef()[1].toStringRef());
1020 },
1021 "(str[] a) -> str[]");
1022 testArgTypes<c10::ArrayRef<Tensor>, c10::List<Tensor>>::test(
1023 c10::ArrayRef<Tensor>({dummyTensor(c10::DispatchKey::CPUTensorId), dummyTensor(c10::DispatchKey::CUDATensorId)}), [] (c10::ArrayRef<Tensor> v) {
1024 EXPECT_EQ(2, v.size());
1025 EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v[0]));
1026 EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v[1]));
1027 },
1028 c10::List<Tensor>({dummyTensor(c10::DispatchKey::CUDATensorId), dummyTensor(c10::DispatchKey::CPUTensorId)}), [] (const IValue& v) {
1029 EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
1030 EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
1031 EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
1032 },
1033 "(Tensor[] a) -> Tensor[]");
1034
1035
1036 // std::array list types (with empty list)
1037 testArgTypes<std::array<double, 0>>::test(
1038 std::array<double, 0>(), [] (std::array<double, 0> v) {},
1039 std::array<double, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<double>>().size()));},
1040 "(float[0] a) -> float[0]");
1041 testArgTypes<std::array<int64_t, 0>>::test(
1042 std::array<int64_t, 0>(), [] (std::array<int64_t, 0> v) {},
1043 std::array<int64_t, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<c10::List<int64_t>>().size()));},
1044 "(int[0] a) -> int[0]");
1045 testArgTypes<std::array<bool, 0>>::test(
1046 std::array<bool, 0>(), [] (std::array<bool, 0> v) {},
1047 std::array<bool, 0>(), [] (const IValue& v) {EXPECT_EQ(0, (v.to<std::array<bool, 0>>().size()));},
1048 "(bool[0] a) -> bool[0]");
1049 testArgTypes<std::array<std::string, 0>>::test(
1050 std::array<std::string, 0>(), [] (std::array<std::string, 0> v) {EXPECT_EQ(0, v.size());},
1051 std::array<std::string, 0>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
1052 "(str[0] a) -> str[0]");
1053
1054
1055 // std::array list types (with non-empty list)
1056 testArgTypes<std::array<double, 2>>::test(
1057 std::array<double, 2>({1.5, 2.5}), [] (std::array<double, 2> v) {expectListEquals({1.5, 2.5}, v);},
1058 std::array<double, 2>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<std::array<double, 2>>());},
1059 "(float[2] a) -> float[2]");
1060 testArgTypes<std::array<int64_t, 2>>::test(
1061 std::array<int64_t, 2>({1, 2}), [] (std::array<int64_t, 2> v) {expectListEquals({1, 2}, v);},
1062 std::array<int64_t, 2>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<std::array<int64_t, 2>>());},
1063 "(int[2] a) -> int[2]");
1064 testArgTypes<std::array<bool, 2>>::test(
1065 std::array<bool, 2>({true, false}), [] (std::array<bool, 2> v) {expectListEquals({true, false}, v);},
1066 std::array<bool, 2>({true, false}), [] (const IValue& v) {expectListEquals({true, false}, v.to<std::array<bool, 2>>());},
1067 "(bool[2] a) -> bool[2]");
1068 testArgTypes<std::array<std::string, 2>>::test(
1069 std::array<std::string, 2>({"first", "second"}), [] (std::array<std::string, 2> v) {expectListEquals({"first", "second"}, v);},
1070 std::array<std::string, 2>({"first", "second"}), [] (const IValue& v) {
1071 EXPECT_EQ(2, v.toListRef().size());
1072 EXPECT_EQ("first", v.toListRef()[0].toStringRef());
1073 EXPECT_EQ("second", v.toListRef()[1].toStringRef());
1074 },
1075 "(str[2] a) -> str[2]");
1076 testArgTypes<std::array<Tensor, 2>>::test(
1077 std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CPUTensorId), dummyTensor(c10::DispatchKey::CUDATensorId)}), [] (std::array<Tensor, 2> v) {
1078 EXPECT_EQ(2, v.size());
1079 EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v[0]));
1080 EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v[1]));
1081 },
1082 std::array<Tensor, 2>({dummyTensor(c10::DispatchKey::CUDATensorId), dummyTensor(c10::DispatchKey::CPUTensorId)}), [] (const IValue& v) {
1083 EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
1084 EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
1085 EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
1086 },
1087 "(Tensor[2] a) -> Tensor[2]");
1088
1089
1090 // deprecated list types (with empty list)
1091 testArgTypes<std::vector<double>>::test<TestLegacyAPI>(
1092 std::vector<double>(), [] (const std::vector<double>& v) {EXPECT_EQ(0, v.size());},
1093 std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<double>>().size());},
1094 "(float[] a) -> float[]");
1095 testArgTypes<std::vector<int64_t>>::test<TestLegacyAPI>(
1096 std::vector<int64_t>(), [] (const std::vector<int64_t>& v) {EXPECT_EQ(0, v.size());},
1097 std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
1098 "(int[] a) -> int[]");
1099   // Note: vector<bool> is not supported, use List<bool> instead.
1100 testArgTypes<std::vector<std::string>>::test<TestLegacyAPI>(
1101 std::vector<std::string>(), [] (const std::vector<std::string>& v) {EXPECT_EQ(0, v.size());},
1102 std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toListRef().size());},
1103 "(str[] a) -> str[]");
1104
1105
1106 // deprecated list types (with non-empty list)
1107 testArgTypes<std::vector<double>>::test<TestLegacyAPI>(
1108 std::vector<double>({1.5, 2.5}), [] (const std::vector<double>& v) {expectListEquals({1.5, 2.5}, v);},
1109 std::vector<double>({3.5, 4.5}), [] (const IValue& v) {expectListEquals({3.5, 4.5}, v.to<c10::List<double>>());},
1110 "(float[] a) -> float[]");
1111 testArgTypes<std::vector<int64_t>>::test<TestLegacyAPI>(
1112 std::vector<int64_t>({1, 2}), [] (const std::vector<int64_t>& v) {expectListEquals({1, 2}, v);},
1113 std::vector<int64_t>({3, 4}), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
1114 "(int[] a) -> int[]");
1115   // Note: vector<bool> is not supported, use List<bool> instead.
1116 testArgTypes<std::vector<std::string>>::test<TestLegacyAPI>(
1117 std::vector<std::string>({"first", "second"}), [] (const std::vector<std::string>& v) {expectListEquals({"first", "second"}, v);},
1118 std::vector<std::string>({"first", "second"}), [] (const IValue& v) {
1119 EXPECT_EQ(2, v.toListRef().size());
1120 EXPECT_EQ("first", v.toListRef()[0].toStringRef());
1121 EXPECT_EQ("second", v.toListRef()[1].toStringRef());
1122 },
1123 "(str[] a) -> str[]");
1124 testArgTypes<std::vector<Tensor>>::test<TestLegacyAPI>(
1125 std::vector<Tensor>({dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA)}), [] (const std::vector<Tensor>& v) {
1126 EXPECT_EQ(2, v.size());
1127 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.at(0)));
1128 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.at(1)));
1129 },
1130 std::vector<Tensor>({dummyTensor(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU)}), [] (const IValue& v) {
1131 EXPECT_EQ(2, v.to<c10::List<at::Tensor>>().size());
1132 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(0)));
1133 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.to<c10::List<at::Tensor>>().get(1)));
1134 },
1135 "(Tensor[] a) -> Tensor[]");
1136
1137 // Test optional of list (with nullopt)
1138 testArgTypes<std::optional<c10::List<int64_t>>>::test(
1139 std::optional<c10::List<int64_t>>(std::nullopt), [] (const std::optional<c10::List<int64_t>>& v) {EXPECT_FALSE(v.has_value());},
1140 std::optional<c10::List<int64_t>>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
1141 "(int[]? a) -> int[]?");
1142
1143 // Test optional of list (with empty list)
1144 testArgTypes<std::optional<c10::List<int64_t>>>::test(
1145 std::optional<c10::List<int64_t>>(c10::List<int64_t>({})), [] (const std::optional<c10::List<int64_t>>& v) {EXPECT_EQ(0, v.value().size());},
1146 std::optional<c10::List<int64_t>>(c10::List<int64_t>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<int64_t>>().size());},
1147 "(int[]? a) -> int[]?");
1148
1149 // Test optional of list (with values)
1150 testArgTypes<std::optional<c10::List<int64_t>>>::test(
1151 std::optional<c10::List<int64_t>>(c10::List<int64_t>({1, 2})), [] (const std::optional<c10::List<int64_t>>& v) {expectListEquals({1, 2}, v.value());},
1152 std::optional<c10::List<int64_t>>(c10::List<int64_t>({3, 4})), [] (const IValue& v) {expectListEquals({3, 4}, v.to<c10::List<int64_t>>());},
1153 "(int[]? a) -> int[]?");
1154
1155 // Test list of optional (with empty list)
1156 testArgTypes<c10::List<::std::optional<int64_t>>>::test(
1157 c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const c10::List<::std::optional<int64_t>>& v) {EXPECT_EQ(0, v.size());},
1158 c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<::std::optional<int64_t>>>().size());},
1159 "(int?[] a) -> int?[]");
1160
1161 // Test list of optional (with values)
1162 testArgTypes<c10::List<::std::optional<int64_t>>>::test(
1163 c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, std::nullopt, 2})), [] (const c10::List<::std::optional<int64_t>>& v) {expectListEquals<std::optional<int64_t>>({3, std::nullopt, 2}, v);},
1164 c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, std::nullopt, 2})), [] (const IValue& v) {expectListEquals<std::optional<int64_t>>({3, std::nullopt, 2}, v.to<c10::List<::std::optional<int64_t>>>());},
1165 "(int?[] a) -> int?[]");
1166
1167 // dict types
1168 c10::Dict<std::string, std::string> str_dict;
1169 str_dict.insert("key1", "value1");
1170 str_dict.insert("key2", "value2");
1171 testArgTypes<c10::Dict<std::string, std::string>>::test(
1172 str_dict, [] (c10::Dict<std::string, std::string> v) {
1173 EXPECT_EQ(2, v.size());
1174 EXPECT_EQ("value1", v.at("key1"));
1175 EXPECT_EQ("value2", v.at("key2"));
1176 },
1177 str_dict, [] (const IValue& v) {
1178 c10::Dict<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(v.toGenericDict());
1179 EXPECT_EQ(2, dict.size());
1180 EXPECT_EQ("value1", dict.at("key1"));
1181 EXPECT_EQ("value2", dict.at("key2"));
1182 },
1183 "(Dict(str, str) a) -> Dict(str, str)");
1184 c10::Dict<int64_t, Tensor> tensor_dict;
1185 tensor_dict.insert(1, dummyTensor(c10::DispatchKey::CPU));
1186 tensor_dict.insert(2, dummyTensor(c10::DispatchKey::CUDA));
1187 testArgTypes<c10::Dict<int64_t, Tensor>>::test(
1188 tensor_dict, [] (c10::Dict<int64_t, Tensor> v) {
1189 EXPECT_EQ(2, v.size());
1190 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.at(1)));
1191 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.at(2)));
1192 },
1193 tensor_dict, [] (const IValue& v) {
1194 c10::Dict<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(v.toGenericDict());
1195 EXPECT_EQ(2, dict.size());
1196 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(dict.at(1)));
1197 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(dict.at(2)));
1198 },
1199 "(Dict(int, Tensor) a) -> Dict(int, Tensor)");
1200
1201 // deprecated dict types
1202 std::unordered_map<std::string, std::string> str_map;
1203 str_map.emplace("key1", "value1");
1204 str_map.emplace("key2", "value2");
1205 testArgTypes<std::unordered_map<std::string, std::string>>::test<TestLegacyAPI>(
1206 str_map, [] (std::unordered_map<std::string, std::string> v) {
1207 EXPECT_EQ(2, v.size());
1208 EXPECT_EQ("value1", v.at("key1"));
1209 EXPECT_EQ("value2", v.at("key2"));
1210 },
1211 str_map, [] (const IValue& v) {
1212 c10::Dict<std::string, std::string> dict = c10::impl::toTypedDict<std::string, std::string>(v.toGenericDict());
1213 EXPECT_EQ(2, dict.size());
1214 EXPECT_EQ("value1", dict.at("key1"));
1215 EXPECT_EQ("value2", dict.at("key2"));
1216 },
1217 "(Dict(str, str) a) -> Dict(str, str)");
1218 std::unordered_map<int64_t, Tensor> tensor_map;
1219 tensor_map.emplace(1, dummyTensor(c10::DispatchKey::CPU));
1220 tensor_map.emplace(2, dummyTensor(c10::DispatchKey::CUDA));
1221 testArgTypes<std::unordered_map<int64_t, Tensor>>::test<TestLegacyAPI>(
1222 tensor_map, [] (std::unordered_map<int64_t, Tensor> v) {
1223 EXPECT_EQ(2, v.size());
1224 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(v.at(1)));
1225 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(v.at(2)));
1226 },
1227 tensor_map, [] (const IValue& v) {
1228 c10::Dict<int64_t, Tensor> dict = c10::impl::toTypedDict<int64_t, Tensor>(v.toGenericDict());
1229 EXPECT_EQ(2, dict.size());
1230 EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(dict.at(1)));
1231 EXPECT_EQ(c10::DispatchKey::CUDA, extractDispatchKey(dict.at(2)));
1232 },
1233 "(Dict(int, Tensor) a) -> Dict(int, Tensor)");
1234
1235 // weird deeply nested type
1236 using DeeplyNestedType = c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>>;
1237 auto makeDeeplyNestedObject = [] () -> DeeplyNestedType {
1238 c10::Dict<int64_t, std::string> inner3;
1239 inner3.insert(1, "1");
1240 c10::List<::std::optional<c10::Dict<int64_t, std::string>>> inner2;
1241 inner2.push_back(std::move(inner3));
1242 c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>> inner1;
1243 inner1.insert("key", std::move(inner2));
1244 c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>> result;
1245 result.push_back(inner1);
1246 return result;
1247 };
1248 testArgTypes<DeeplyNestedType>::test(
1249 makeDeeplyNestedObject(), [] (const DeeplyNestedType& v) {EXPECT_EQ("1", v.get(0).at("key").get(0).value().at(1));},
1250 makeDeeplyNestedObject(), [] (const IValue& v) {EXPECT_EQ("1", v.to<DeeplyNestedType>().get(0).at("key").get(0).value().at(1));},
1251 "(Dict(str, Dict(int, str)?[])[] a) -> Dict(str, Dict(int, str)?[])[]");
1252 }
1253
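// def() rejects schemas whose overload name is "default" or a dunder-style name such as
// "__name__"; both registrations below are expected to throw with the messages checked.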
TEST(NewOperatorRegistrationTest,erroroutwithinvalidoverloadname)1254 TEST(NewOperatorRegistrationTest, erroroutwithinvalidoverloadname) {
1255 auto m = MAKE_TORCH_LIBRARY(_test);
1256 expectThrows<c10::Error>([&] {
1257 m.def("dummy.default(Tensor self) -> Tensor");
1258 }, "default is not a legal overload name for aten operators");
1259 expectThrows<c10::Error>([&] {
1260 m.def("dummy.__name__(Tensor self) -> Tensor");
1261 }, "__name__ is not a legal overload name for aten operators");
1262 }
1263
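// testBasics exercises the torch::Library API end to end: def() with a schema only, def()
// with schema plus kernel, def() with just a name (schema inferred from the kernel), and
// impl() keyed by DeviceType as well as by raw DispatchKey. Outside of tests, the same
// registrations are typically written with the macro form, e.g. (illustrative sketch only):
//   TORCH_LIBRARY(_test, m) {
//     m.def("dummy(Tensor self) -> Tensor");
//   }
//   TORCH_LIBRARY_IMPL(_test, CPU, m) {
//     m.impl("dummy", [](const Tensor& self) { return self; });
//   }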
TEST(NewOperatorRegistrationTest,testBasics)1264 TEST(NewOperatorRegistrationTest, testBasics) {
1265 auto m = MAKE_TORCH_LIBRARY(_test);
1266 m.def("dummy(Tensor self) -> Tensor");
1267 m.def("dummy1(Tensor self) -> Tensor");
1268 m.def("dummy2(Tensor self) -> Tensor");
1269 m.def("dummy3(Tensor self, Tensor other) -> Tensor", [](const Tensor& self, const Tensor& other) { return self; });
1270 m.def("dummy4", [](const Tensor& self, const Tensor& other) { return other; });
1271 m.impl("dummy", c10::DeviceType::CPU, [](const Tensor& self) { return self; });
1272 m.impl("dummy", c10::DeviceType::XLA, [](const Tensor& self) { return self; });
1273 m.impl("dummy", c10::DeviceType::Lazy, [](const Tensor& self) { return self; });
1274 // Internal API
1275 m.impl("dummy2", c10::DispatchKey::CPU, [](const Tensor& self) { return self; });
1276 m.impl("dummy2", c10::DispatchKey::XLA, [](const Tensor& self) { return self; });
1277 m.impl("dummy2", c10::DispatchKey::Lazy, [](const Tensor& self) { return self; });
1278
1279 ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy", ""}).has_value());
1280 // Should have a schema even if there are no impls
1281 ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy1", ""}).has_value());
1282 ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy2", ""}).has_value());
1283 ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy3", ""}).has_value());
1284 ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy4", ""}).has_value());
1285 }
1286
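// def()'d operators are visible through both findSchema() and findOp(); the impl()-only
// "impl1" below is still expected to be discoverable via findOp() even without a schema.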
TEST(NewOperatorRegistrationTest,importTopLevel)1287 TEST(NewOperatorRegistrationTest, importTopLevel) {
1288 auto m = MAKE_TORCH_LIBRARY(test);
1289 m.def("def1(Tensor self) -> Tensor");
1290 m.def("def2(Tensor self) -> Tensor", [](const Tensor& x) { return x; });
1291 m.def("def3", [](const Tensor& x) { return x; });
1292
1293 auto m2 = MAKE_TORCH_LIBRARY_IMPL(test, CatchAll);
1294 m2.impl("impl1", [](const Tensor& x) { return x; });
1295
1296 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def1", ""}).has_value());
1297 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def2", ""}).has_value());
1298 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def3", ""}).has_value());
1299 ASSERT_TRUE(Dispatcher::singleton().findOp({"test::def1", ""}).has_value());
1300 ASSERT_TRUE(Dispatcher::singleton().findOp({"test::def2", ""}).has_value());
1301 ASSERT_TRUE(Dispatcher::singleton().findOp({"test::def3", ""}).has_value());
1302 ASSERT_TRUE(Dispatcher::singleton().findOp({"test::impl1", ""}).has_value());
1303 }
1304
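// The part after the '.' in a def() schema becomes the overload name, i.e. the second
// element of the {name, overload_name} key used for the dispatcher lookups below.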
TEST(NewOperatorRegistrationTest,overload)1305 TEST(NewOperatorRegistrationTest, overload) {
1306 auto m = MAKE_TORCH_LIBRARY(test);
1307 m.def("fn(Tensor self) -> Tensor");
1308 m.def("fn.overload1(Tensor self, Tensor other) -> Tensor");
1309 m.def("fn.overload2(Tensor self, Tensor other, Tensor alpha) -> Tensor");
1310
1311 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::fn", ""}).has_value());
1312 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::fn", "overload1"}).has_value());
1313 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::fn", "overload2"}).has_value());
1314 }
1315
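// A library created for namespace "test" may only define operators in that namespace;
// a def() that names a different namespace ("retest::") is expected to throw.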
TEST(NewOperatorRegistrationTest,importNamespace)1316 TEST(NewOperatorRegistrationTest, importNamespace) {
1317 auto m = MAKE_TORCH_LIBRARY(test);
1318 m.def("def1(Tensor self) -> Tensor");
1319 m.def("def2(Tensor self) -> Tensor", [](const Tensor& x) { return x; });
1320 m.def("def3", [](const Tensor& x) { return x; });
1321 m.impl("impl1", [](const Tensor& x) { return x; });
1322 expectThrows<c10::Error>([&] {
1323 m.def("retest::def1(Tensor self) -> Tensor");
1324 }, "");
1325
1326 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def1", ""}).has_value());
1327 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def2", ""}).has_value());
1328 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def3", ""}).has_value());
1329 ASSERT_TRUE(Dispatcher::singleton().findOp({"test::impl1", ""}).has_value());
1330 }
1331
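// def() accepts a raw schema string, a torch::schema() wrapper (optionally with an explicit
// AliasAnalysisKind), or a pre-parsed schema from torch::jit::parseSchema(). The assertions
// below check the resulting alias analysis kinds: FROM_SCHEMA for the first two forms,
// PURE_FUNCTION where explicitly requested, and the default kind for the parsed schema.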
TEST(NewOperatorRegistrationTest,schema)1332 TEST(NewOperatorRegistrationTest, schema) {
1333 auto m = MAKE_TORCH_LIBRARY(test);
1334 m.def("def1(Tensor self) -> Tensor");
1335 m.def(torch::schema("def2(Tensor self) -> Tensor"));
1336 m.def(torch::schema("def3(Tensor self) -> Tensor", c10::AliasAnalysisKind::PURE_FUNCTION));
1337 m.def(torch::jit::parseSchema("def4(Tensor self) -> Tensor"));
1338
1339 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def1", ""}).has_value());
1340 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def2", ""}).has_value());
1341 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def3", ""}).has_value());
1342 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""}).has_value());
1343
1344 EXPECT_EQ(Dispatcher::singleton().findSchema({"test::def1", ""})->schema().aliasAnalysis(), c10::AliasAnalysisKind::FROM_SCHEMA);
1345 EXPECT_EQ(Dispatcher::singleton().findSchema({"test::def2", ""})->schema().aliasAnalysis(), c10::AliasAnalysisKind::FROM_SCHEMA);
1346 EXPECT_EQ(Dispatcher::singleton().findSchema({"test::def3", ""})->schema().aliasAnalysis(), c10::AliasAnalysisKind::PURE_FUNCTION);
1347 ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""})->schema().isDefaultAliasAnalysisKind());
1348 }
1349
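// The catch-all impl() below registers to CompositeImplicitAutograd, which is consulted
// before the boxed CPU backend fallback, so the lambda is expected to run instead of the fallback.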
TEST(NewOperatorRegistrationTest,whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel)1350 TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) {
1351 auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
1352 m1.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
1353
1354 bool called = false;
1355 auto m = MAKE_TORCH_LIBRARY(test);
1356 m.def("fn(Tensor t, str input) -> ()");
1357 m.impl("fn", [&] (Tensor, std::string) { called = true; });
1358
1359 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1360 ASSERT_TRUE(op.has_value());
1361
1362 called = false;
1363 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
1364 // CatchAll now maps to CompositeImplicitAutograd and has higher precedence than backend fallback.
1365 EXPECT_TRUE(called);
1366 }
1367
TEST(NewOperatorRegistrationTest,whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel)1368 TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
1369 auto m = MAKE_TORCH_LIBRARY(test);
1370 m.def("fn(Tensor dummy) -> ()");
1371 m.impl("fn", c10::DispatchKey::CPU, nonautograd_kernel);
1372 m.impl("fn", c10::DispatchKey::Autograd, autograd_kernel);
1373
1374 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1375 ASSERT_TRUE(op.has_value());
1376
1377 called_nonautograd = called_autograd = false;
1378 callOp(*op, dummyTensor(DispatchKey::CPU));
1379 EXPECT_TRUE(called_nonautograd);
1380 EXPECT_FALSE(called_autograd);
1381 }
1382
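// A single CompositeImplicitAutograd kernel is expected to serve every key probed below:
// CPU, XLA, Lazy and SparseCPU, both with and without requires_grad.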
TEST(NewOperatorRegistrationTest,dispatchWithCompositeImplicitAutogradKernel)1383 TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradKernel) {
1384 bool math_called = false;
1385 auto m = MAKE_TORCH_LIBRARY(test);
1386 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1387
1388 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1389 ASSERT_TRUE(op.has_value());
1390
1391 {
1392 ASSERT_FALSE(math_called);
1393 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1394 ASSERT_TRUE(math_called);
1395 }
1396
1397 {
1398 math_called = false;
1399 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1400 ASSERT_TRUE(math_called);
1401 }
1402
1403 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1404 math_called = false;
1405 callOp(*op, dummyTensor(key));
1406 ASSERT_TRUE(math_called);
1407 }
1408
1409 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1410 math_called = false;
1411 callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1412 ASSERT_TRUE(math_called);
1413 }
1414
1415 {
1416 math_called = false;
1417 callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
1418 ASSERT_TRUE(math_called);
1419 }
1420
1421 {
1422 math_called = false;
1423 callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
1424 ASSERT_TRUE(math_called);
1425 }
1426 }
1427
TEST(NewOperatorRegistrationTest,dispatchWithCompositeImplicitAutogradAndAutogradKernel)1428 TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndAutogradKernel) {
1429 bool math_called = false;
1430 bool autograd_called = false;
1431 auto m = MAKE_TORCH_LIBRARY(test);
1432 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1433 m.impl("fn", c10::DispatchKey::Autograd, [&](const Tensor& x) { autograd_called = true; return x; });
1434
1435 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1436 ASSERT_TRUE(op.has_value());
1437
1438 // CompositeImplicitAutograd has higher precedence than Autograd
1439 {
1440 math_called = autograd_called = false;
1441 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1442 ASSERT_TRUE(math_called);
1443 ASSERT_FALSE(autograd_called);
1444 }
1445
1446 {
1447 math_called = autograd_called = false;
1448 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1449 ASSERT_TRUE(math_called);
1450 ASSERT_FALSE(autograd_called);
1451 }
1452 }
1453
TEST(NewOperatorRegistrationTest,dispatchWithCompositeImplicitAutogradAndCatchAllKernel)1454 TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndCatchAllKernel) {
1455 bool math_called = false;
1456 bool catchall_called = false;
1457 auto m = MAKE_TORCH_LIBRARY(test);
1458 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1459 m.impl("fn", [&](const Tensor& x) { catchall_called = true; return x; });
1460
1461 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1462 ASSERT_TRUE(op.has_value());
1463
1464   // catchAll now maps to CompositeImplicitAutograd, which means we have two registrations to the CompositeImplicitAutograd key.
1465 // The last registration is used.
1466 {
1467 catchall_called = math_called = false;
1468 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1469 ASSERT_FALSE(math_called);
1470 ASSERT_TRUE(catchall_called);
1471 }
1472
1473 {
1474 catchall_called = math_called = false;
1475 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1476 ASSERT_FALSE(math_called);
1477 ASSERT_TRUE(catchall_called);
1478 }
1479 }
1480
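// An AutogradCPU kernel is expected to beat CompositeImplicitAutograd only for CPU autograd
// tensors; plain CPU calls and all CUDA calls below still reach the composite kernel.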
TEST(NewOperatorRegistrationTest,AutogradBackendOverridesCompositeImplicitAutogradKernel)1481 TEST(NewOperatorRegistrationTest, AutogradBackendOverridesCompositeImplicitAutogradKernel) {
1482 bool math_called = false;
1483 bool autograd_called = false;
1484 auto m = MAKE_TORCH_LIBRARY(test);
1485 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1486 m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autograd_called = true; return x; });
1487
1488 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1489 ASSERT_TRUE(op.has_value());
1490
1491 {
1492 math_called = autograd_called = false;
1493 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1494 ASSERT_TRUE(math_called);
1495 ASSERT_FALSE(autograd_called);
1496 }
1497
1498 {
1499 math_called = autograd_called = false;
1500 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1501 ASSERT_TRUE(autograd_called);
1502 ASSERT_FALSE(math_called);
1503 }
1504
1505 {
1506 math_called = autograd_called = false;
1507 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1508 ASSERT_TRUE(math_called);
1509 ASSERT_FALSE(autograd_called);
1510 }
1511
1512 {
1513 math_called = autograd_called = false;
1514 callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1515 ASSERT_TRUE(math_called);
1516 ASSERT_FALSE(autograd_called);
1517 }
1518 }
1519
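// A CPU backend kernel is expected to beat CompositeImplicitAutograd for CPU tensors
// (AutogradCPU falls through to it), while CUDA calls still reach the composite kernel.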
TEST(NewOperatorRegistrationTest,BackendOverridesCompositeImplicitAutogradKernel)1520 TEST(NewOperatorRegistrationTest, BackendOverridesCompositeImplicitAutogradKernel) {
1521 bool math_called = false;
1522 bool backend_called = false;
1523 auto m = MAKE_TORCH_LIBRARY(test);
1524 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
1525 m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; });
1526
1527 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1528 ASSERT_TRUE(op.has_value());
1529
1530 {
1531 math_called = backend_called = false;
1532 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1533 ASSERT_TRUE(backend_called);
1534 ASSERT_FALSE(math_called);
1535 }
1536
1537 {
1538     // AutogradCPU is fallthrough, so the call reaches the CPU kernel
1539 math_called = backend_called = false;
1540 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1541 ASSERT_TRUE(backend_called);
1542 ASSERT_FALSE(math_called);
1543 }
1544
1545 {
1546 math_called = backend_called = false;
1547 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1548 ASSERT_TRUE(math_called);
1549 ASSERT_FALSE(backend_called);
1550 }
1551
1552 {
1553 math_called = backend_called = false;
1554 callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1555 ASSERT_TRUE(math_called);
1556 ASSERT_FALSE(backend_called);
1557 }
1558 }
1559
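// A CompositeExplicitAutograd kernel is expected to serve all backend keys probed below;
// the Autograd keys are fallthrough, so autograd tensors end up in the same kernel.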
TEST(NewOperatorRegistrationTest,dispatchWithCompositeExplicitAutogradKernel)1560 TEST(NewOperatorRegistrationTest, dispatchWithCompositeExplicitAutogradKernel) {
1561 bool called = false;
1562 auto m = MAKE_TORCH_LIBRARY(test);
1563 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, [&](const Tensor& x) { called = true; return x; }));
1564
1565 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1566 ASSERT_TRUE(op.has_value());
1567
1568 {
1569 ASSERT_FALSE(called);
1570 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1571 ASSERT_TRUE(called);
1572 }
1573
1574 {
1575 called = false;
1576 // AutogradCPU is fallthrough, calls CPU kernel
1577 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1578 ASSERT_TRUE(called);
1579 }
1580
1581 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1582 called = false;
1583 callOp(*op, dummyTensor(key));
1584 ASSERT_TRUE(called);
1585 }
1586
1587 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1588 called = false;
1589 // Autograd{XLA, Lazy} is fallthrough, calls XLA / Lazy kernel
1590 callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1591 ASSERT_TRUE(called);
1592 }
1593
1594 {
1595 called = false;
1596 callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
1597 ASSERT_TRUE(called);
1598 }
1599
1600 {
1601 called = false;
1602 // AutogradCPU is fallthrough, calls CPU kernel
1603 callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
1604 ASSERT_TRUE(called);
1605 }
1606 }
1607
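// With both composite aliases registered, the CompositeExplicitAutograd kernel is expected
// to win for every key probed below; the CompositeImplicitAutograd kernel is never called.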
TEST(NewOperatorRegistrationTest,dispatchWithCompositeExplicitAutogradAndCompositeImplicitAutogradKernel)1608 TEST(NewOperatorRegistrationTest, dispatchWithCompositeExplicitAutogradAndCompositeImplicitAutogradKernel) {
1609 bool backend_called = false;
1610 bool math_called = false;
1611 auto m = MAKE_TORCH_LIBRARY(test);
1612 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, [&](const Tensor& x) { backend_called = true; return x; }));
1613 m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
1614
1615 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1616 ASSERT_TRUE(op.has_value());
1617
1618 {
1619 backend_called = math_called = false;
1620 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1621 ASSERT_TRUE(backend_called);
1622 ASSERT_FALSE(math_called);
1623 }
1624
1625 {
1626 backend_called = math_called = false;
1627 // AutogradCPU is fallthrough, calls CPU kernel
1628 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1629 ASSERT_FALSE(math_called);
1630 ASSERT_TRUE(backend_called);
1631 }
1632
1633 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1634 backend_called = math_called = false;
1635 callOp(*op, dummyTensor(key));
1636 ASSERT_TRUE(backend_called);
1637 ASSERT_FALSE(math_called);
1638 }
1639
1640 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1641 backend_called = math_called = false;
1642 // Autograd{XLA, Lazy} is fallthrough, calls XLA / Lazy kernel
1643 callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1644 ASSERT_FALSE(math_called);
1645 ASSERT_TRUE(backend_called);
1646 }
1647
1648 {
1649 backend_called = math_called = false;
1650 callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
1651 ASSERT_TRUE(backend_called);
1652 ASSERT_FALSE(math_called);
1653 }
1654
1655 {
1656 backend_called = math_called = false;
1657 // AutogradOther is fallthrough, calls SparseCPU kernel
1658 callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
1659 ASSERT_FALSE(math_called);
1660 ASSERT_TRUE(backend_called);
1661 }
1662 }
1663
TEST(NewOperatorRegistrationTest,BackendOverridesCompositeExplicitAutogradKernel)1664 TEST(NewOperatorRegistrationTest, BackendOverridesCompositeExplicitAutogradKernel) {
1665 bool default_called = false;
1666 bool backend_called = false;
1667 auto m = MAKE_TORCH_LIBRARY(test);
1668 m.def("fn", torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, [&](const Tensor& x) { default_called = true; return x; }));
1669 m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; });
1670
1671 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1672 ASSERT_TRUE(op.has_value());
1673
1674 {
1675 default_called = backend_called = false;
1676 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1677 ASSERT_TRUE(backend_called);
1678 ASSERT_FALSE(default_called);
1679 }
1680
1681 {
1682 default_called = backend_called = false;
1683 // AutogradCPU is fallthrough, calls CPU kernel
1684 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1685 ASSERT_TRUE(backend_called);
1686 ASSERT_FALSE(default_called);
1687 }
1688
1689 {
1690 default_called = backend_called = false;
1691 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1692 ASSERT_TRUE(default_called);
1693 ASSERT_FALSE(backend_called);
1694 }
1695
1696 {
1697 default_called = backend_called = false;
1698 // AutogradCUDA is fallthrough, calls CUDA kernel
1699 callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1700 ASSERT_TRUE(default_called);
1701 ASSERT_FALSE(backend_called);
1702 }
1703 }
1704
1705
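// torch::dispatch() attaches a dispatch key to the kernel passed to def(); the kernel is then
// only expected to run for tensors carrying that key (or, for kAutograd, tensors that require grad).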
TEST(NewOperatorRegistrationTest,dispatch)1706 TEST(NewOperatorRegistrationTest, dispatch) {
1707 bool cpu_called = false;
1708 bool cuda_called = false;
1709 bool autograd_called = false;
1710 auto m = MAKE_TORCH_LIBRARY(test);
1711 m.def("fn_cpu", torch::dispatch(c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; }));
1712 m.def("fn_cuda", torch::dispatch(c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; }));
1713 m.def("fn_autograd", torch::dispatch(c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; }));
1714
1715 {
1716 auto op = Dispatcher::singleton().findSchema({"test::fn_cpu", ""});
1717 ASSERT_TRUE(op.has_value());
1718 ASSERT_FALSE(cpu_called);
1719 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1720 ASSERT_TRUE(cpu_called);
1721 }
1722
1723 {
1724 auto op = Dispatcher::singleton().findSchema({"test::fn_cuda", ""});
1725 ASSERT_TRUE(op.has_value());
1726 ASSERT_FALSE(cuda_called);
1727 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1728 ASSERT_TRUE(cuda_called);
1729 }
1730
1731 {
1732 auto op = Dispatcher::singleton().findSchema({"test::fn_autograd", ""});
1733 ASSERT_TRUE(op.has_value());
1734 ASSERT_FALSE(autograd_called);
1735 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1736 ASSERT_TRUE(autograd_called);
1737 }
1738
1739 for (c10::DispatchKey key : {c10::DispatchKey::XLA, c10::DispatchKey::Lazy}) {
1740 autograd_called = false;
1741 auto op = Dispatcher::singleton().findSchema({"test::fn_autograd", ""});
1742 ASSERT_TRUE(op.has_value());
1743 callOp(*op, dummyTensor(key, /*requires_grad=*/true));
1744 ASSERT_TRUE(autograd_called);
1745 }
1746 }
1747
TEST(NewOperatorRegistrationTest,dispatchAutogradPrecedence)1748 TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) {
1749 bool cpu_called = false;
1750 auto m = MAKE_TORCH_LIBRARY(test);
1751 m.def("fn", torch::dispatch(c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; }));
1752
1753 {
1754 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1755 ASSERT_TRUE(op.has_value());
1756 ASSERT_FALSE(cpu_called);
1757 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1758 ASSERT_TRUE(cpu_called);
1759 }
1760
1761 {
1762     // AutogradCPU is fallthrough, calls CPU kernel
1763 cpu_called = false;
1764 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1765 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1766 ASSERT_TRUE(cpu_called);
1767 }
1768
1769 bool autograd_called = false;
1770 m.impl("fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; });
1771
1772 {
1773 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1774 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1775 ASSERT_TRUE(autograd_called);
1776 }
1777
1778 // Autograd backend kernel has higher precedence than Autograd alias.
1779 bool autogradcpu_called = false;
1780 m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autogradcpu_called = true; return x; });
1781
1782 {
1783 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1784 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1785 ASSERT_TRUE(autogradcpu_called);
1786 }
1787 }
1788
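// FPGA has no dedicated Autograd key and maps to AutogradOther. Combining an FPGA kernel with
// a CompositeImplicitAutograd kernel is ambiguous there, so dispatching an FPGA tensor that
// requires grad is expected to throw the error checked below.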
TEST(NewOperatorRegistrationTest,throwsWhenRegisterToBackendMapsToAutogradOther)1789 TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) {
1790   bool fpga_called = false;
1791   bool math_called = false;
1792 auto m = MAKE_TORCH_LIBRARY(test);
1793 m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; }));
1794 m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
1795
1796 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1797 ASSERT_TRUE(op.has_value());
1798
1799 {
1800 callOp(*op, dummyTensor(c10::DispatchKey::FPGA));
1801 ASSERT_TRUE(fpga_called);
1802 }
1803
1804 {
1805 expectThrows<c10::Error>([&] {
1806 callOp(*op, dummyTensor(c10::DispatchKey::FPGA, /*requires_grad=*/true));
1807 }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther.");
1808 }
1809 }
1810
TEST(NewOperatorRegistrationTest,dispatchMultipleTensors)1811 TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
1812 bool privateuse1_called = false;
1813 bool catchall_called = false;
1814   // Similar to the in-tree AutogradCPU/AutogradCUDA keys, out-of-tree backends usually register
1815   // a fallthrough kernel for AutogradPrivateUse1.
1816 auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1);
1817 m1.fallback(CppFunction::makeFallthrough());
1818
1819 auto m = MAKE_TORCH_LIBRARY(test);
1820 m.def("fn", torch::dispatch(c10::DispatchKey::PrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; }));
1821 m.impl("fn", [&](const Tensor& x, const Tensor& y) { catchall_called = true; return x; });
1822
1823 {
1824 auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1825 ASSERT_TRUE(op.has_value());
1826 callOp(*op, dummyTensor(c10::DispatchKey::PrivateUse1), dummyTensor(c10::DispatchKey::CPU));
1827 ASSERT_TRUE(privateuse1_called);
1828 }
1829
1830 {
1831 auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1832 ASSERT_TRUE(op.has_value());
1833 ASSERT_FALSE(catchall_called);
1834 callOp(*op, dummyTensor(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CPU));
1835 ASSERT_TRUE(catchall_called);
1836 }
1837
1838 {
1839 auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1840 ASSERT_TRUE(op.has_value());
1841 catchall_called = false;
1842 callOp(*op,
1843 dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true),
1844 dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1845 ASSERT_TRUE(catchall_called);
1846 }
1847
1848 {
1849 auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1850 ASSERT_TRUE(op.has_value());
1851 catchall_called = false;
1852 privateuse1_called = false;
1853 callOp(*op,
1854 dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
1855 dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1856 ASSERT_FALSE(catchall_called);
1857 ASSERT_TRUE(privateuse1_called);
1858 }
1859
1860 m.impl("fn", c10::DispatchKey::AutogradPrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; });
1861
1862 {
1863 auto op = Dispatcher::singleton().findOp({"test::fn", ""});
1864 ASSERT_TRUE(op.has_value());
1865 privateuse1_called = false;
1866 callOp(*op,
1867 dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
1868 dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1869 ASSERT_TRUE(privateuse1_called);
1870 }
1871 }
1872
TEST(NewOperatorRegistrationTest,registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite)1873 TEST(NewOperatorRegistrationTest, registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite) {
1874 bool math_called = false;
1875 bool cpu_called = false;
1876 auto m = MAKE_TORCH_LIBRARY(test);
1877 m.def("fn(Tensor dummy) -> Tensor");
1878 m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
1879 m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
1880
1881 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1882 ASSERT_TRUE(op.has_value());
1883
1884 {
1885 math_called = cpu_called = false;
1886 // Meta should redispatch to the AutogradOther backend,
1887 // which the composite kernel should be registered to.
1888 callOp(*op, dummyTensor(c10::DispatchKey::Meta, /*requires_grad=*/true));
1889 ASSERT_TRUE(math_called);
1890 ASSERT_FALSE(cpu_called);
1891 }
1892 }
1893
TEST(NewOperatorRegistrationTest,dispatchMultiple)1894 TEST(NewOperatorRegistrationTest, dispatchMultiple) {
1895 bool cpu_called = false;
1896 bool cuda_called = false;
1897 bool autograd_called = false;
1898 auto m = MAKE_TORCH_LIBRARY(test);
1899 m.def("fn(Tensor self) -> Tensor");
1900 // NB: Direct use of DispatchKey is discouraged; use the DeviceType
1901 // k-synonyms instead
1902 m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
1903 m.impl("fn", c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; });
1904 m.impl("fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; });
1905
1906 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1907 ASSERT_TRUE(op.has_value());
1908
1909 {
1910 ASSERT_FALSE(cpu_called);
1911 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1912 ASSERT_TRUE(cpu_called);
1913
1914 ASSERT_FALSE(cuda_called);
1915 callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
1916 ASSERT_TRUE(cuda_called);
1917 }
1918
1919 {
1920 ASSERT_FALSE(autograd_called);
1921 callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
1922 ASSERT_TRUE(autograd_called);
1923
1924 autograd_called = false;
1925 callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
1926 ASSERT_TRUE(autograd_called);
1927 }
1928 }
1929
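// No kernel is registered for _test::dummy here, so the call is expected to reach the boxed
// CPU backend fallback, which (per the expectation below) appends the operator name to the
// string argument.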
TEST(NewOperatorRegistrationTest,fallback)1930 TEST(NewOperatorRegistrationTest, fallback) {
1931 auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
1932 m.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
1933
1934 auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
1935
1936 auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
1937 ASSERT_TRUE(op.has_value());
1938 auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
1939 EXPECT_EQ("hello _test::dummy", stack[1].toStringRef());
1940 }
1941
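// The BackendSelect kernel receives the current DispatchKeySet, masks off everything up to and
// including BackendSelect, and redispatches; the call is then expected to land on the CPU kernel.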
TEST(NewOperatorRegistrationTest,BackendSelectRedispatchesToCPU)1942 TEST(NewOperatorRegistrationTest, BackendSelectRedispatchesToCPU) {
1943 bool cpu_called = false;
1944 bool backend_generic_called = false;
1945 auto m = MAKE_TORCH_LIBRARY(test);
1946 auto after_backend_select = c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::BackendSelect);
1947 m.def("fn(Tensor self) -> Tensor");
1948 m.impl("fn", c10::kCPU, [&](const Tensor& x) { cpu_called = true; return x; });
1949 m.impl("fn", c10::DispatchKey::BackendSelect, [&](c10::DispatchKeySet ks, const Tensor& x) {
1950 backend_generic_called = true;
1951 auto op = c10::Dispatcher::singleton().findSchema({"test::fn", ""}).value().typed<Tensor (const Tensor&)>();
1952 return c10::Dispatcher::singleton().redispatch<Tensor, const Tensor&>(op, ks & after_backend_select, x);
1953 });
1954
1955 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
1956 ASSERT_TRUE(op.has_value());
1957 callOp(*op, dummyTensor(c10::DispatchKey::CPU));
1958 ASSERT_TRUE(cpu_called);
1959 ASSERT_TRUE(backend_generic_called);
1960 }
1961
TEST(NewOperatorRegistrationTest,TorchLibraryTwiceIsError)1962 TEST(NewOperatorRegistrationTest, TorchLibraryTwiceIsError) {
1963 {
1964 auto m = MAKE_TORCH_LIBRARY(test);
1965 expectThrows<c10::Error>([] {
1966 auto m2 = MAKE_TORCH_LIBRARY(test);
1967 }, "Only a single TORCH_LIBRARY");
1968 }
1969 // Ensure it's ok after deregistering
1970 auto m = MAKE_TORCH_LIBRARY(test);
1971 }
1972
dummy_fn(const Tensor & x)1973 Tensor dummy_fn(const Tensor& x) {
1974 return x;
1975 }
1976
TEST(NewOperatorRegistrationTest,CppFunction)1977 TEST(NewOperatorRegistrationTest, CppFunction) {
1978 // Just show off the possible ways to register functions
1979 auto m = MAKE_TORCH_LIBRARY(test);
1980 m.def("fn1", &dummy_fn);
1981 // C++ will implicitly convert function to function pointer
1982 // c.f. https://en.cppreference.com/w/cpp/language/implicit_conversion#Function_to_pointer
1983 m.def("fn2", dummy_fn);
1984 m.def("fn3", [](const Tensor& x) { return x; });
1985 // These require explicit schema
1986 m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough());
1987 m.def("fn5(Tensor x) -> Tensor", CppFunction::makeFromUnboxedFunction(dummy_fn));
1988 m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
1989 }
1990
1991 // Some internal tests that have to be done from C++
1992
1993 struct OpRegistrationListenerForDelayedListenerTest : public c10::OpRegistrationListener {
1994 int64_t num_registers_ = 0;
1995 int64_t num_deregisters_ = 0;
onOperatorRegistered__anondbb8b74d0111::OpRegistrationListenerForDelayedListenerTest1996 void onOperatorRegistered(const OperatorHandle& op) override {
1997 num_registers_++;
1998 }
onOperatorDeregistered__anondbb8b74d0111::OpRegistrationListenerForDelayedListenerTest1999 void onOperatorDeregistered(const OperatorHandle& op) override {
2000 num_deregisters_++;
2001 }
2002 };
2003
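// Registering only an impl() is not expected to notify the listener; the notification fires once
// the def() arrives, and the matching deregistration fires when the def()'ing library is destroyed.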
TEST(NewOperatorRegistrationTest,testDelayedListener)2004 TEST(NewOperatorRegistrationTest, testDelayedListener) {
2005 auto listener = std::make_unique<OpRegistrationListenerForDelayedListenerTest>();
2006 auto listener_ptr = listener.get();
2007 auto registry = Dispatcher::singleton().addRegistrationListener(std::move(listener));
2008 int64_t initial_num_registers = listener_ptr->num_registers_;
2009 int64_t initial_num_deregisters = listener_ptr->num_deregisters_;
2010 auto op = Dispatcher::singleton().findOp({"_test::dummy", ""});
2011 ASSERT_FALSE(op.has_value());
2012 auto m1 = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
2013 m1.impl("dummy", [](const Tensor& self) { return self; });
2014 EXPECT_EQ(initial_num_registers, listener_ptr->num_registers_);
2015 {
2016 auto m2 = MAKE_TORCH_LIBRARY(_test);
2017 m2.def("dummy(Tensor self) -> Tensor");
2018 EXPECT_EQ(initial_num_registers + 1, listener_ptr->num_registers_);
2019 }
2020 EXPECT_EQ(initial_num_deregisters + 1, listener_ptr->num_deregisters_);
2021 }
2022
TEST(NewOperatorRegistrationTest,testImplNoDefGetsCaught)2023 TEST(NewOperatorRegistrationTest, testImplNoDefGetsCaught) {
2024 auto danglingImpls = Dispatcher::singleton().findDanglingImpls();
2025 std::string error_str = "Discovered operators that have been registered through the dispatcher"
2026 " without explicitly specifying their schemas. Please do so using"
2027 " the TORCH_LIBRARY macro. Suspect operators:\n";
2028 for (auto& op : danglingImpls) {
2029 auto& op_name = op.operator_name();
2030 error_str += "\t" + op_name.name;
2031 if (op_name.overload_name != "") {
2032 error_str += "." + op_name.overload_name;
2033 }
2034 error_str += "\n";
2035 }
2036 ASSERT_EQ(danglingImpls.size(), 0) << error_str;
2037 }
2038
2039 bool called_kernel_cpu = false;
2040 bool called_kernel_autograd = false;
2041 bool called_kernel_tracing = false;
2042
cpu_kernel(Tensor)2043 void cpu_kernel(Tensor) {
2044 called_kernel_cpu = true;
2045 }
2046
2047 // autograd kernel that redispatches. Explicitly takes in and updates the DispatchKeySet
autograd_kernel_redispatching_with_DispatchKeySet(c10::DispatchKeySet ks,Tensor a)2048 void autograd_kernel_redispatching_with_DispatchKeySet(c10::DispatchKeySet ks, Tensor a) {
2049 called_kernel_autograd = true;
2050 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2051 auto updatedDispatchKeySet = ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
2052 callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, updatedDispatchKeySet, a);
2053 }
2054
2055 // autograd kernel that redispatches. Does not take in a DispatchKeySet
autograd_kernel_redispatching_without_DispatchKeySet(c10::DispatchKeySet ks,Tensor a)2056 void autograd_kernel_redispatching_without_DispatchKeySet(c10::DispatchKeySet ks, Tensor a) {
2057 called_kernel_autograd = true;
2058 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2059 auto updatedDispatchKeySet = ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
2060 callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, updatedDispatchKeySet, a);
2061 }
2062
2063 // tracing kernel that redispatches. Explicitly takes in and updates the DispatchKeySet
tracing_kernel_redispatching_with_DispatchKeySet(c10::DispatchKeySet ks,Tensor a)2064 void tracing_kernel_redispatching_with_DispatchKeySet(c10::DispatchKeySet ks, Tensor a) {
2065 called_kernel_tracing = true;
2066 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2067 auto updatedDispatchKeySet = ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer);
2068 callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, updatedDispatchKeySet, a);
2069 }
2070
TEST(OperatorRegistrationTest,callKernelsWithDispatchKeySetConvention_call_redispatchesToLowerPriorityKernels)2071 TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_call_redispatchesToLowerPriorityKernels) {
2072 auto m = MAKE_TORCH_LIBRARY(test);
2073 m.def("fn(Tensor dummy) -> ()");
2074 m.impl("fn", c10::DispatchKey::CPU, cpu_kernel);
2075 m.impl("fn", c10::DispatchKey::AutogradCPU, autograd_kernel_redispatching_with_DispatchKeySet);
2076 m.impl("fn", c10::DispatchKey::Tracer, tracing_kernel_redispatching_with_DispatchKeySet);
2077
2078 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2079 ASSERT_TRUE(op.has_value());
2080
2081 called_kernel_cpu = called_kernel_autograd = called_kernel_tracing = false;
2082 auto tracing_autograd_cpu_set = c10::DispatchKeySet()
2083 .add(c10::DispatchKey::Tracer)
2084 .add(c10::DispatchKey::AutogradCPU)
2085 .add(c10::DispatchKey::CPU);
2086
2087 // call Tracing -> call Autograd -> call CPU
2088 callOpUnboxed<void, Tensor>(*op, dummyTensor(tracing_autograd_cpu_set, true));
2089 EXPECT_TRUE(called_kernel_tracing);
2090 EXPECT_TRUE(called_kernel_autograd);
2091 EXPECT_TRUE(called_kernel_cpu);
2092 }
2093
TEST(OperatorRegistrationTest,callKernelsWithDispatchKeySetConvention_callBoxed_redispatchesToLowerPriorityKernels)2094 TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_callBoxed_redispatchesToLowerPriorityKernels) {
2095 auto m = MAKE_TORCH_LIBRARY(test);
2096 m.def("fn(Tensor dummy) -> ()");
2097 m.impl("fn", c10::DispatchKey::CPU, cpu_kernel);
2098 m.impl("fn", c10::DispatchKey::AutogradCPU, autograd_kernel_redispatching_with_DispatchKeySet);
2099 m.impl("fn", c10::DispatchKey::Tracer, tracing_kernel_redispatching_with_DispatchKeySet);
2100
2101 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2102 ASSERT_TRUE(op.has_value());
2103
2104 called_kernel_cpu = called_kernel_autograd = called_kernel_tracing = false;
2105 auto tracing_autograd_cpu_set = c10::DispatchKeySet()
2106 .add(c10::DispatchKey::Tracer)
2107 .add(c10::DispatchKey::AutogradCPU)
2108 .add(c10::DispatchKey::CPU);
2109
2110 // call Tracing -> call Autograd -> call CPU
2111 callOp<Tensor>(*op, dummyTensor(tracing_autograd_cpu_set, true));
2112 EXPECT_TRUE(called_kernel_tracing);
2113 EXPECT_TRUE(called_kernel_autograd);
2114 EXPECT_TRUE(called_kernel_cpu);
2115 }
2116
TEST(OperatorRegistrationTest,callKernelsWithDispatchKeySetConvention_mixedCallingConventions_redispatchesToLowerPriorityKernels)2117 TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_mixedCallingConventions_redispatchesToLowerPriorityKernels) {
2118 auto m = MAKE_TORCH_LIBRARY(test);
2119 m.def("fn(Tensor dummy) -> ()");
2120 m.impl("fn", c10::DispatchKey::CPU, cpu_kernel);
2121   // The tracing kernel takes in a DispatchKeySet, but the autograd kernel does not;
2122   // the dispatcher should correctly plumb its DispatchKeySet to the tracing kernel and not to the autograd one.
2123 m.impl("fn", c10::DispatchKey::AutogradCPU, autograd_kernel_redispatching_without_DispatchKeySet);
2124 m.impl("fn", c10::DispatchKey::Tracer, tracing_kernel_redispatching_with_DispatchKeySet);
2125
2126 auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
2127 ASSERT_TRUE(op.has_value());
2128
2129 called_kernel_cpu = called_kernel_autograd = called_kernel_tracing = false;
2130 auto tracing_autograd_cpu_set = c10::DispatchKeySet()
2131 .add(c10::DispatchKey::Tracer)
2132 .add(c10::DispatchKey::AutogradCPU)
2133 .add(c10::DispatchKey::CPU);
2134
2135 // call Tracing -> call Autograd -> call CPU
2136 callOpUnboxed<void, Tensor>(*op, dummyTensor(tracing_autograd_cpu_set, true));
2137 EXPECT_TRUE(called_kernel_tracing);
2138 EXPECT_TRUE(called_kernel_autograd);
2139 EXPECT_TRUE(called_kernel_cpu);
2140 }
2141
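// getRegistrationsForDispatchKey(std::nullopt) returns every registered operator, while passing a
// key filters to operators that have a kernel for that key, so the CPU set must be a subset of the
// full set (checked via std::includes after sorting both).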
TEST(OperatorRegistrationTest,getRegistrationsForDispatchKey)2142 TEST(OperatorRegistrationTest, getRegistrationsForDispatchKey) {
2143 // should return every registered op
2144 auto all_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(std::nullopt);
2145 // should return every registered op with a cpu kernel
2146 auto cpu_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
2147 ASSERT_TRUE(all_ops.size() > 0);
2148 ASSERT_TRUE(cpu_ops.size() > 0);
2149
2150   auto cmp_lambda = [](const c10::OperatorName& a, const c10::OperatorName& b) -> bool {
2151 return c10::toString(a) < c10::toString(b);
2152 };
2153
2154 std::sort(all_ops.begin(), all_ops.end(), cmp_lambda);
2155 std::sort(cpu_ops.begin(), cpu_ops.end(), cmp_lambda);
2156 ASSERT_TRUE(std::includes(all_ops.begin(), all_ops.end(), cpu_ops.begin(), cpu_ops.end(), cmp_lambda));
2157 }
2158
symint_op(const Tensor & self,int64_t length)2159 Tensor symint_op(const Tensor& self, int64_t length) {
2160 return self.clone();
2161 }
2162
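// The schema declares SymInt but the kernel takes a plain int64_t; calls through both the int64_t
// and SymInt typed() signatures are expected to work, while a const SymInt& signature is rejected
// as a signature mismatch.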
TEST(OperatorRegistrationTest,TestSymNonSymCompatibility)2163 TEST(OperatorRegistrationTest, TestSymNonSymCompatibility) {
2164 auto m = MAKE_TORCH_LIBRARY(_test);
2165 m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
2166 auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
2167 m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op));
2168
2169 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
2170 "_test::symint_op", "");
2171
2172 opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
2173 opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
2174
2175 expectThrows<c10::Error>([&] {
2176 opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
2177 }, "Tried to access or call an operator with a wrong signature");
2178 }
2179
symint_op2(const Tensor & self,c10::SymInt length)2180 Tensor symint_op2(const Tensor& self, c10::SymInt length) {
2181 return self.clone();
2182 }
2183
TEST(OperatorRegistrationTest,TestSymSymCompatibility)2184 TEST(OperatorRegistrationTest, TestSymSymCompatibility) {
2185 auto m = MAKE_TORCH_LIBRARY(_test);
2186 m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
2187 auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
2188 m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op2));
2189
2190 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
2191 "_test::symint_op", "");
2192
2193 opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
2194 opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
2195 // TODO: We should reject this on principle, but today it accidentally works
2196 // due to going through the boxed calling convention.
2197 //
2198   // First, we check whether the const SymInt& signature counts as a SymInt
2199   // signature. It does not, because we only accept a signature as SymInt if it
2200   // contains exactly SymInt. So we then check whether there is a non-SymInt
2201   // kernel. There is none, because we only registered a real SymInt kernel.
2202   // When this occurs, we fall back to the boxed calling convention. The
2203   // boxed calling convention can deal with const SymInt& fine, as during
2204   // boxing it will just create a SymInt to push onto the argument stack and
2205   // everything is fine.
2206 opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
2207 }
2208
symint_op3(const Tensor & self,const c10::SymInt & length)2209 Tensor symint_op3(const Tensor& self, const c10::SymInt& length) {
2210 return self.clone();
2211 }
2212
TEST(OperatorRegistrationTest,TestSymSymRefCompatibility)2213 TEST(OperatorRegistrationTest, TestSymSymRefCompatibility) {
2214 auto m = MAKE_TORCH_LIBRARY(_test);
2215 m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
2216 auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
2217
2218 expectThrows<c10::Error>([&] {
2219 m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op3));
2220 }, "doesn't match the expected function schema");
2221 }
2222
2223 }
2224
2225 #pragma GCC diagnostic pop
2226