xref: /aosp_15_r20/external/pytorch/c10/test/core/DispatchKeySet_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <cstddef>
4 #include <iterator>
5 #include <unordered_set>
6 
7 #include <c10/core/DispatchKeySet.h>
8 #include <c10/util/irange.h>
9 
10 using namespace c10;
11 
12 // This test exists not to be comprehensive, but to more clearly show
13 // what the semantics of DispatchKeySet are.
TEST(DispatchKeySet,ShowSemantics)14 TEST(DispatchKeySet, ShowSemantics) {
15   // the "CPU" dispatch key is an instance of a per-backend-functionality key.
16   // It corresponds to "dense" functionality, "CPU" backend.
17   // This means that it gets a dense functionality bit, and a cpu backend bit
18   // set.
19   auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
20   ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
21   ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
22   ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
23 
24   auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
25   ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
26   ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
27   ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));
28 
29   // You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
30   // dispatch keys. You are allowed to directly create keysets out of them!
31   auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) |
32       DispatchKeySet(BackendComponent::CPUBit);
33   ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
34   ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
35   ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
36   ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);
37 
38   // Similarly, the AutogradCUDA key gets 2 bits in the keyset:
39   // The "Autograd" functionality bit, and the "CUDA" backend bit
40   auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
41   ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
42   ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));
43 
44   // Because DispatchKeySet uses a condensed internal representation, you cannot
45   // use it to represent the FULL cross product of backends and functionalities
46   // for example:
47   auto autograd_dense_cpu_cuda = DispatchKeySet(
48       {DispatchKey::AutogradFunctionality,
49        DispatchKey::Dense,
50        DispatchKey::CUDA,
51        DispatchKey::CPU});
52   // this keyset has all of the building block keys:
53   ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
54   ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
55   ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
56   ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));
57 
58   // and it also has the "runtime" keys that correspond to the full
59   // cross-product of functionality
60   ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
61   ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
62   ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
63   ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));
64 
65   // This means that there's no way to represent a keyset with, say, only
66   // Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
67   // inheriting the full set of functionalities + backends of its keys. This
68   // means that the below keysets are all indistinguishable from each other.
69   ASSERT_EQ(
70       autograd_dense_cpu_cuda,
71       DispatchKeySet(
72           {DispatchKey::AutogradCUDA,
73            DispatchKey::AutogradCPU,
74            DispatchKey::CUDA,
75            DispatchKey::CPU}));
76   ASSERT_EQ(
77       autograd_dense_cpu_cuda,
78       DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU}));
79   ASSERT_EQ(
80       autograd_dense_cpu_cuda,
81       DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU}));
82 
83   // ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~
84 
85   // Iterators allow you to iterate individually through the DispatchKey's in a
86   // DispatchKeySet
87   auto empty_set = DispatchKeySet();
88   ASSERT_EQ(*empty_set.begin(), *empty_set.end());
89 
90   // However, only keys that correspond to actual runtime indices of kernels in
91   // the operator table show up when you iterate through a keyset. i.e.
92   // DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an
93   // iterator.
94   auto dense_cpu_iter = dense_cpu_set.begin();
95   ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU);
96   ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end());
97 
98   auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin();
99   ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU);
100   ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA);
101   ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU);
102   ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA);
103   ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end());
104 
105   // But other "functionality bits" that are not defined per-backend DO get
106   // their own slots in the operator table.
107   auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) |
108       DispatchKeySet(
109                           {DispatchKey::FPGA, // runtime key
110                            DispatchKey::Functionalize, // runtime key
111                            DispatchKey::Dense}); // NOT a runtime key
112   auto mixed_iter = mixed_keyset.begin();
113   ASSERT_EQ(*mixed_iter++, DispatchKey::CPU);
114   ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA);
115   ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize);
116   ASSERT_EQ(*mixed_iter, *mixed_keyset.end());
117 }
118 
TEST(DispatchKeySet, Empty) {
  // A default-constructed keyset must report no key at all, must say it is
  // empty, and must compare equal to any other default-constructed keyset.
  DispatchKeySet no_keys;
  const auto last_raw =
      static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
  for (uint8_t raw = 0; raw <= last_raw; ++raw) {
    const auto key = static_cast<DispatchKey>(raw);
    // Undefined is excluded from the membership query.
    if (key != DispatchKey::Undefined) {
      ASSERT_FALSE(no_keys.has(key));
    }
  }
  ASSERT_TRUE(no_keys.empty());

  DispatchKeySet another_default;
  ASSERT_TRUE(no_keys == another_default);
}
133 
134 // This covers all keys that correspond to a single backend bit, e.g.
135 // BackendComponent::CPUBit. Even though these are NOT runtime keys, we still
136 // allow adding them directly to a keyset
TEST(DispatchKeySet, SingletonBackendComponent) {
  // Builds a one-key set for every index in [1, num_backends) and checks
  // that re-adding the key or unioning the set with itself changes nothing.
  // NOTE(review): the index is cast to DispatchKey, not BackendComponent,
  // although this test is described as covering backend bits - confirm that
  // this is the intended enum.
  for (const auto raw : c10::irange(1, num_backends)) {
    const auto key = static_cast<DispatchKey>(raw);
    DispatchKeySet singleton(key);

    auto built_by_add = DispatchKeySet().add(key);
    auto re_added = singleton.add(key);
    auto self_union = singleton | singleton;

    ASSERT_EQ(singleton, singleton);
    ASSERT_EQ(singleton, built_by_add);
    ASSERT_EQ(singleton, re_added);
    ASSERT_EQ(singleton, self_union);
    ASSERT_FALSE(singleton.empty());
    ASSERT_TRUE(singleton.has(key));
  }
}
149 
150 // This covers all keys that correspond to a single functionality bit:
151 // - runtime, not-per-backend functionality keys, e.g.
152 // DispatchKey::FuncTorchBatched
153 // - runtime, "fake backend" keys, e.g. DispatchKey::FPGA
154 // - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense
155 //   Even though it's not a runtime key, we still allow adding it directly to a
156 //   keyset.
157 // DispatchKey::
TEST(DispatchKeySet, SingletonFunctionalityKeys) {
  // For every index in [1, num_functionality_keys): a one-key set is stable
  // under re-adding and self-union, contains its key, and becomes the
  // default (empty) keyset once that key is removed again.
  for (const auto raw : c10::irange(1, num_functionality_keys)) {
    const auto key = static_cast<DispatchKey>(raw);
    DispatchKeySet singleton(key);

    auto built_by_add = DispatchKeySet().add(key);
    auto re_added = singleton.add(key);
    auto self_union = singleton | singleton;

    ASSERT_EQ(singleton, singleton);
    ASSERT_EQ(singleton, built_by_add);
    ASSERT_EQ(singleton, re_added);
    ASSERT_EQ(singleton, self_union);
    ASSERT_FALSE(singleton.empty());
    ASSERT_TRUE(singleton.has(key));

    auto after_removal = singleton.remove(key);
    ASSERT_EQ(after_removal, DispatchKeySet());
  }
}
171 
172 // This covers runtime keys that are per-backend,
173 // and take up more than one bit in a DispatchKeySet. They take up one
174 // functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA,
175 // AutogradCPU, AutogradCUDA
TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
  // Walks every raw value from StartOfDenseBackends through
  // EndOfRuntimeBackendKeys and checks the one-key set built from it.
  for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
       i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
       i++) {
    auto tid = static_cast<DispatchKey>(i);
    // Skip these because they aren't real keys. The list mirrors the one used
    // by the other tests in this file that walk the same key range.
    if (tid == DispatchKey::StartOfDenseBackends ||
        tid == DispatchKey::StartOfSparseBackends ||
        tid == DispatchKey::StartOfSparseCsrBackends ||
        tid == DispatchKey::StartOfQuantizedBackends ||
        tid == DispatchKey::StartOfNestedTensorBackends ||
        tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
      continue;
    }
    DispatchKeySet sing(tid);
    ASSERT_EQ(sing, sing);
    ASSERT_EQ(sing, DispatchKeySet().add(tid));
    ASSERT_EQ(sing, sing.add(tid));
    ASSERT_EQ(sing, sing | sing);
    ASSERT_FALSE(sing.empty());
    ASSERT_TRUE(sing.has(tid));

    auto functionality_key = toFunctionalityKey(tid);
    auto backend_key = toBackendComponent(tid);
    // These two sets should be equivalent:
    // DispatchKeySet(DispatchKey::CPU)
    // DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit})
    auto expected_ks =
        DispatchKeySet(functionality_key) | DispatchKeySet(backend_key);
    ASSERT_EQ(sing, expected_ks);
    // Removing the runtime key itself is expected to leave just the backend
    // bit behind, i.e. these two sets should be equivalent:
    // DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::CPU)
    // DispatchKeySet(BackendComponent::CPUBit)
    expected_ks = DispatchKeySet(backend_key);
    ASSERT_EQ(sing.remove(tid), expected_ks);
  }
}
211 
TEST(DispatchKeySet,DoubletonPerBackend)212 TEST(DispatchKeySet, DoubletonPerBackend) {
213   for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
214        i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
215        i++) {
216     for (uint8_t j = i + 1;
217          j <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
218          j++) {
219       ASSERT_LT(i, j);
220       auto tid1 = static_cast<DispatchKey>(i);
221       auto tid2 = static_cast<DispatchKey>(j);
222 
223       // Skip these because they aren't real keys.
224       if (tid1 == DispatchKey::StartOfDenseBackends ||
225           tid1 == DispatchKey::StartOfSparseBackends ||
226           tid1 == DispatchKey::StartOfSparseCsrBackends ||
227           tid1 == DispatchKey::StartOfQuantizedBackends ||
228           tid1 == DispatchKey::StartOfNestedTensorBackends ||
229           tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
230         continue;
231       if (tid2 == DispatchKey::StartOfDenseBackends ||
232           tid2 == DispatchKey::StartOfSparseBackends ||
233           tid2 == DispatchKey::StartOfSparseCsrBackends ||
234           tid2 == DispatchKey::StartOfQuantizedBackends ||
235           tid2 == DispatchKey::StartOfNestedTensorBackends ||
236           tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
237         continue;
238 
239       auto backend1 = toBackendComponent(tid1);
240       auto backend2 = toBackendComponent(tid2);
241       auto functionality1 = toFunctionalityKey(tid1);
242       auto functionality2 = toFunctionalityKey(tid2);
243 
244       auto combined = DispatchKeySet({tid1, tid2});
245       // The combined set has the backend bits
246       ASSERT_TRUE(combined.has_backend(backend1));
247       ASSERT_TRUE(combined.has_backend(backend2));
248       // and it has the backend bits
249       ASSERT_TRUE(combined.has(functionality1));
250       ASSERT_TRUE(combined.has(functionality2));
251       // and it has the original two runtime keys
252       ASSERT_TRUE(combined.has(tid1));
253       ASSERT_TRUE(combined.has(tid2));
254 
255       // Add all of the keys in the keyset to a real set
256       std::unordered_set<DispatchKey> visited_keys;
257       auto iter = combined.begin();
258       while (*iter != *combined.end()) {
259         visited_keys.insert(*iter);
260         ++iter;
261       }
262       std::unordered_set<DispatchKey> expected_keys;
263       expected_keys.insert(
264           toRuntimePerBackendFunctionalityKey(functionality1, backend1));
265       expected_keys.insert(
266           toRuntimePerBackendFunctionalityKey(functionality1, backend2));
267       expected_keys.insert(
268           toRuntimePerBackendFunctionalityKey(functionality2, backend1));
269       expected_keys.insert(
270           toRuntimePerBackendFunctionalityKey(functionality2, backend2));
271       ASSERT_EQ(expected_keys, visited_keys);
272 
273       if (backend1 == backend2 || functionality1 == functionality2) {
274         // We have two runtime keys, with either the same backend or the same
275         // per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
276         // {AutogradCPU, AutogradCUDA} There should be 2 total runtime keys in
277         // this set.
278         ASSERT_EQ(2, visited_keys.size());
279       } else {
280         // since i and j are different keys, they should not have the same
281         // functionality and backend
282         ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
283         // We have two runtime keys, that have different backends + per-backend
284         // functionalities. So we should expect the full cross product of
285         // runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU,
286         // then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
287         ASSERT_EQ(4, visited_keys.size());
288       }
289     }
290   }
291 }
292 
TEST(DispatchKeySet, Full) {
  // A keyset built from DispatchKeySet::FULL must contain every key whose
  // index lies in [1, num_functionality_keys), but not the
  // EndOfFunctionalityKeys marker.
  DispatchKeySet everything(DispatchKeySet::FULL);
  for (const auto raw : c10::irange(1, num_functionality_keys)) {
    ASSERT_TRUE(everything.has(static_cast<DispatchKey>(raw)));
  }
  ASSERT_FALSE(everything.has(DispatchKey::EndOfFunctionalityKeys));
}
301 
TEST(DispatchKeySet, IteratorBasicOps) {
  // Exercises iterator construction, dereference, comparison and both
  // increment forms on an empty, a full and a single-key keyset.
  DispatchKeySet none;
  DispatchKeySet all(DispatchKeySet::FULL);
  DispatchKeySet only_cpu = none.add(DispatchKey::CPU);

  // Dereferencing: both ends of an empty set yield the end marker, while a
  // one-key set starts at that key.
  ASSERT_EQ(*none.begin(), DispatchKey::EndOfFunctionalityKeys);
  ASSERT_EQ(*none.end(), DispatchKey::EndOfFunctionalityKeys);
  ASSERT_EQ(*only_cpu.begin(), DispatchKey::CPU);

  // Comparison: begin == end exactly when there is nothing to visit.
  ASSERT_TRUE(none.begin() == none.end());
  ASSERT_TRUE(all.begin() != all.end());

  // Post-increment hands back the old position; pre-increment the new one.
  ASSERT_TRUE(all.begin() == all.begin()++);
  ASSERT_TRUE(all.begin() != ++all.begin());
}
319 
TEST(DispatchKeySet, getHighestPriorityBackendTypeId) {
  // Each case builds a two-key set and checks which key
  // c10::highestPriorityBackendTypeId selects from it.

  // AutogradCPU isn't a backend key so it is ignored
  DispatchKeySet cpu_with_autograd(
      {DispatchKey::AutogradCPU, DispatchKey::CPU});
  const auto from_cpu = c10::highestPriorityBackendTypeId(cpu_with_autograd);
  ASSERT_EQ(DispatchKey::CPU, from_cpu);

  // Functionalize isn't a backend key so it is ignored
  DispatchKeySet sparse_with_functionalize(
      {DispatchKey::Functionalize, DispatchKey::SparseCUDA});
  const auto from_sparse =
      c10::highestPriorityBackendTypeId(sparse_with_functionalize);
  ASSERT_EQ(DispatchKey::SparseCUDA, from_sparse);

  DispatchKeySet sparse_csr_with_functionalize(
      {DispatchKey::Functionalize, DispatchKey::SparseCsrCUDA});
  const auto from_sparse_csr =
      c10::highestPriorityBackendTypeId(sparse_csr_with_functionalize);
  ASSERT_EQ(DispatchKey::SparseCsrCUDA, from_sparse_csr);

  // quantizedCUDA has higher priority than CUDA
  DispatchKeySet cuda_and_quantized(
      {DispatchKey::CUDA, DispatchKey::QuantizedCUDA});
  const auto from_quantized =
      c10::highestPriorityBackendTypeId(cuda_and_quantized);
  ASSERT_EQ(DispatchKey::QuantizedCUDA, from_quantized);
}
344 
TEST(DispatchKeySet, IteratorEmpty) {
  // Walking a default-constructed keyset from begin() to end() must take
  // zero steps.
  DispatchKeySet no_keys;
  uint8_t steps = 0;

  auto cursor = no_keys.begin();
  while (cursor != no_keys.end()) {
    steps++;
    ++cursor;
  }
  ASSERT_EQ(steps, 0);
}
354 
TEST(DispatchKeySet, IteratorCrossProduct) {
  // The iterator should return all runtime keys in the set,
  // including the cross product of {backends} x {functionalities}
  auto ks =
      DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
      DispatchKeySet(
          {DispatchKey::Dense,
           DispatchKey::FPGA,
           DispatchKey::AutogradFunctionality});

  auto iter = ks.begin();
  // iterate through dense backends first.
  ASSERT_EQ(DispatchKey::CPU, *(iter++));
  ASSERT_EQ(DispatchKey::CUDA, *(iter++));
  // FPGA doesn't have a backend bit, so it isn't included in the cross product.
  ASSERT_EQ(DispatchKey::FPGA, *(iter++));
  // iterate through the autograd keys last.
  ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
  ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
  // Nothing else may follow: the five keys above are the whole iteration.
  ASSERT_EQ(*iter, *ks.end());
}
375 
TEST(DispatchKeySet, IteratorFull) {
  DispatchKeySet all_keys(DispatchKeySet::FULL);

  // Total # of runtime entries includes an entry for DispatchKey::Undefined,
  // which is not included when iterating through the DispatchKeySet.
  const std::ptrdiff_t visited =
      std::distance(all_keys.begin(), all_keys.end());
  ASSERT_EQ(visited, std::ptrdiff_t{num_runtime_entries} - 1);
}
TEST(DispatchKeySet, FailAtEndIterator) {
  // Constructing an iterator at index num_backends + num_functionality_keys
  // is accepted; one index further must raise c10::Error.
  DispatchKeySet all_keys(DispatchKeySet::FULL);
  uint64_t bits = all_keys.raw_repr();

  // doesn't throw
  DispatchKeySet::iterator(&bits, num_backends + num_functionality_keys);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(
      DispatchKeySet::iterator(
          &bits, num_backends + num_functionality_keys + 1),
      c10::Error);
}
396 
TEST(DispatchKeySet, TestBackendComponentToString) {
  // Every BackendComponent value from 0 through EndOfBackendKeys must print
  // as something other than the unknown-bit placeholder, and no two values
  // may print identically.
  std::unordered_set<std::string> names_so_far;
  const auto last_raw = static_cast<int64_t>(BackendComponent::EndOfBackendKeys);
  for (int64_t raw = 0; raw <= last_raw; ++raw) {
    const auto component = static_cast<BackendComponent>(raw);
    const auto name = std::string(toString(component));
    ASSERT_FALSE(name == "UNKNOWN_BACKEND_BIT");
    // insert() reports through .second whether the name was new.
    ASSERT_TRUE(names_so_far.insert(name).second);
  }
}
409 
TEST(DispatchKeySet, TestEndOfRuntimeBackendKeysAccurate) {
  // Each expansion of the macro below overwrites last_end_key, so after the
  // list has been expanded it holds the EndOf...Backends value of whichever
  // entry came last (presumably one expansion per functionality key - verify
  // against the definition of C10_FORALL_FUNCTIONALITY_KEYS). That value must
  // be EndOfRuntimeBackendKeys.
  DispatchKey last_end_key = DispatchKey::Undefined;
#define RECORD_END_KEY(fullname, prefix) \
  last_end_key = DispatchKey::EndOf##fullname##Backends;
  C10_FORALL_FUNCTIONALITY_KEYS(RECORD_END_KEY)
#undef RECORD_END_KEY
  ASSERT_TRUE(last_end_key == DispatchKey::EndOfRuntimeBackendKeys);
}
417 
TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
  // Every DispatchKey from 0 through EndOfAliasKeys, apart from the synthetic
  // markers, must have a printed name that does not contain "Unknown" and
  // that no earlier key has already produced.

  // These synthetic keys never actually get used and don't need
  // to be printed
  const auto is_synthetic = [](DispatchKey key) {
    return key == DispatchKey::EndOfFunctionalityKeys ||
        key == DispatchKey::StartOfDenseBackends ||
        key == DispatchKey::StartOfQuantizedBackends ||
        key == DispatchKey::StartOfSparseBackends ||
        key == DispatchKey::StartOfSparseCsrBackends ||
        key == DispatchKey::StartOfNestedTensorBackends ||
        key == DispatchKey::StartOfAutogradFunctionalityBackends;
  };

  std::unordered_set<std::string> names_so_far;
  const int last_raw = static_cast<int>(DispatchKey::EndOfAliasKeys);
  for (int i = 0; i <= last_raw; i++) {
    const auto key = static_cast<DispatchKey>(i);
    if (is_synthetic(key)) {
      continue;
    }
    const auto name = std::string(toString(key));
    // On failure, report the raw index together with the previous key's name.
    ASSERT_TRUE(name.find("Unknown") == std::string::npos)
        << i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
        << ")";
    ASSERT_TRUE(names_so_far.find(name) == names_so_far.end());
    names_so_far.insert(name);
  }
}
440