1 #include <gtest/gtest.h>
2
3 #include <cstddef>
4 #include <iterator>
5 #include <unordered_set>
6
7 #include <c10/core/DispatchKeySet.h>
8 #include <c10/util/irange.h>
9
10 using namespace c10;
11
12 // This test exists not to be comprehensive, but to more clearly show
13 // what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
  // Walkthrough of DispatchKeySet semantics: how runtime keys decompose into a
  // functionality bit plus a backend bit, how unions of keys behave, and what
  // iteration over a keyset yields.

  // the "CPU" dispatch key is an instance of a per-backend-functionality key.
  // It corresponds to "dense" functionality, "CPU" backend.
  // This means that it gets a dense functionality bit, and a cpu backend bit
  // set.
  auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
  ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
  ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
  ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));

  auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
  ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
  ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
  ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));

  // You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
  // dispatch keys. You are allowed to directly create keysets out of them!
  auto dense_cpu_set_from_building_blocks =
      DispatchKeySet(DispatchKey::Dense) |
      DispatchKeySet(BackendComponent::CPUBit);
  // Query the set that was just assembled (not the original dense_cpu_set),
  // so these assertions actually exercise the building-block construction.
  ASSERT_TRUE(dense_cpu_set_from_building_blocks.has(DispatchKey::Dense));
  ASSERT_TRUE(
      dense_cpu_set_from_building_blocks.has_backend(BackendComponent::CPUBit));
  ASSERT_TRUE(dense_cpu_set_from_building_blocks.has(DispatchKey::CPU));
  ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);

  // Similarly, the AutogradCUDA key gets 2 bits in the keyset:
  // The "Autograd" functionality bit, and the "CUDA" backend bit
  auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
  ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
  ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));

  // Because DispatchKeySet uses a condensed internal representation, you cannot
  // use it to represent an arbitrary subset of the cross product of backends
  // and functionalities. For example:
  auto autograd_dense_cpu_cuda = DispatchKeySet(
      {DispatchKey::AutogradFunctionality,
       DispatchKey::Dense,
       DispatchKey::CUDA,
       DispatchKey::CPU});
  // this keyset has all of the building block keys:
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));

  // and it also has the "runtime" keys that correspond to the full
  // cross-product of {functionalities} x {backends}: one check per
  // combination.
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCUDA));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));

  // This means that there's no way to represent a keyset with, say, only
  // Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
  // inheriting the full set of functionalities + backends of its keys. This
  // means that the below keysets are all indistinguishable from each other.
  ASSERT_EQ(
      autograd_dense_cpu_cuda,
      DispatchKeySet(
          {DispatchKey::AutogradCUDA,
           DispatchKey::AutogradCPU,
           DispatchKey::CUDA,
           DispatchKey::CPU}));
  ASSERT_EQ(
      autograd_dense_cpu_cuda,
      DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU}));
  ASSERT_EQ(
      autograd_dense_cpu_cuda,
      DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU}));

  // ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~

  // Iterators allow you to iterate individually through the DispatchKeys in a
  // DispatchKeySet. For an empty set, begin() already dereferences to the same
  // key as end().
  auto empty_set = DispatchKeySet();
  ASSERT_EQ(*empty_set.begin(), *empty_set.end());

  // However, only keys that correspond to actual runtime indices of kernels in
  // the operator table show up when you iterate through a keyset. i.e.
  // DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an
  // iterator.
  auto dense_cpu_iter = dense_cpu_set.begin();
  ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU);
  ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end());

  auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin();
  ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU);
  ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA);
  ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU);
  ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA);
  ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end());

  // But other "functionality bits" that are not defined per-backend DO get
  // their own slots in the operator table.
  auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) |
      DispatchKeySet(
          {DispatchKey::FPGA, // runtime key
           DispatchKey::Functionalize, // runtime key
           DispatchKey::Dense}); // NOT a runtime key
  auto mixed_iter = mixed_keyset.begin();
  ASSERT_EQ(*mixed_iter++, DispatchKey::CPU);
  ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA);
  ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize);
  ASSERT_EQ(*mixed_iter, *mixed_keyset.end());
}
118
TEST(DispatchKeySet, Empty) {
  // A default-constructed keyset must report no key as present, must be
  // empty(), and must compare equal to another default-constructed keyset.
  DispatchKeySet no_keys;
  const auto last_raw =
      static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
  for (uint8_t raw = 0; raw <= last_raw; ++raw) {
    const auto key = static_cast<DispatchKey>(raw);
    // Undefined is deliberately never queried.
    // NOTE(review): presumably has() rejects DispatchKey::Undefined - confirm.
    if (key != DispatchKey::Undefined) {
      ASSERT_FALSE(no_keys.has(key));
    }
  }
  ASSERT_TRUE(no_keys.empty());

  DispatchKeySet another_default;
  ASSERT_TRUE(no_keys == another_default);
}
133
134 // This covers all keys that correspond to a single backend bit, e.g.
135 // BackendComponent::CPUBit. Even though these are NOT runtime keys, we still
136 // allow adding them directly to a keyset
TEST(DispatchKeySet, SingletonBackendComponent) {
  // For every raw value in [1, num_backends), build a one-key set and check
  // that re-adding the key or unioning the set with itself leaves it unchanged.
  // NOTE(review): the raw value is cast to DispatchKey although the comment
  // above talks about BackendComponent bits - confirm this is intended.
  for (const auto raw : c10::irange(1, num_backends)) {
    const auto key = static_cast<DispatchKey>(raw);
    DispatchKeySet singleton(key);

    // Equality is reflexive and independent of how the set was built.
    ASSERT_EQ(singleton, singleton);
    auto built_by_add = DispatchKeySet().add(key);
    ASSERT_EQ(singleton, built_by_add);

    // Adding the same key again, or unioning with itself, changes nothing.
    auto readded = singleton.add(key);
    ASSERT_EQ(singleton, readded);
    auto self_union = singleton | singleton;
    ASSERT_EQ(singleton, self_union);

    ASSERT_FALSE(singleton.empty());
    ASSERT_TRUE(singleton.has(key));
  }
}
149
// This covers all keys that correspond to a single functionality bit:
// - runtime, not-per-backend functionality keys, e.g.
//   DispatchKey::FuncTorchBatched
// - runtime, "fake backend" keys, e.g. DispatchKey::FPGA
// - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense.
//   Even though these are not runtime keys, we still allow adding them
//   directly to a keyset.
TEST(DispatchKeySet, SingletonFunctionalityKeys) {
  // For every functionality key in [1, num_functionality_keys), a one-key set
  // must be unchanged by re-adding the key or self-union, and removing the key
  // must give back the empty set.
  for (const auto raw : c10::irange(1, num_functionality_keys)) {
    const auto key = static_cast<DispatchKey>(raw);
    DispatchKeySet singleton(key);

    ASSERT_EQ(singleton, singleton);
    auto built_by_add = DispatchKeySet().add(key);
    ASSERT_EQ(singleton, built_by_add);

    auto readded = singleton.add(key);
    ASSERT_EQ(singleton, readded);
    auto self_union = singleton | singleton;
    ASSERT_EQ(singleton, self_union);

    ASSERT_FALSE(singleton.empty());
    ASSERT_TRUE(singleton.has(key));

    // A single functionality bit is the whole set, so removing it empties it.
    auto after_removal = singleton.remove(key);
    ASSERT_EQ(after_removal, DispatchKeySet());
  }
}
171
172 // This covers runtime keys that are per-backend,
173 // and take up more than one bit in a DispatchKeySet. They take up one
174 // functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA,
175 // AutogradCPU, AutogradCUDA
TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
  // Walks every raw value from StartOfDenseBackends through
  // EndOfRuntimeBackendKeys and checks that a one-key set for each real
  // per-backend runtime key decomposes into exactly one functionality bit plus
  // one backend bit.
  for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
       i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
       i++) {
    auto tid = static_cast<DispatchKey>(i);
    // Skip these because they aren't real keys. The list matches the set of
    // "StartOf*Backends" markers skipped by the other tests in this file.
    if (tid == DispatchKey::StartOfDenseBackends ||
        tid == DispatchKey::StartOfSparseBackends ||
        tid == DispatchKey::StartOfSparseCsrBackends ||
        tid == DispatchKey::StartOfQuantizedBackends ||
        tid == DispatchKey::StartOfNestedTensorBackends ||
        tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
      continue;
    }
    DispatchKeySet sing(tid);
    ASSERT_EQ(sing, sing);
    ASSERT_EQ(sing, DispatchKeySet().add(tid));
    ASSERT_EQ(sing, sing.add(tid));
    ASSERT_EQ(sing, sing | sing);
    ASSERT_FALSE(sing.empty());
    ASSERT_TRUE(sing.has(tid));

    auto functionality_key = toFunctionalityKey(tid);
    auto backend_key = toBackendComponent(tid);
    // These two sets should be equivalent:
    //  DispatchKeySet(DispatchKey::CPU)
    //  DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit})
    auto expected_ks =
        DispatchKeySet(functionality_key) | DispatchKeySet(backend_key);
    ASSERT_EQ(sing, expected_ks);
    // Removing the runtime key is expected to leave only the backend bit, so
    // these two sets should be equivalent:
    //  DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::CPU)
    //  DispatchKeySet(BackendComponent::CPUBit)
    expected_ks = DispatchKeySet(backend_key);
    ASSERT_EQ(sing.remove(tid), expected_ks);
  }
}
211
TEST(DispatchKeySet, DoubletonPerBackend) {
  // For every unordered pair of distinct per-backend runtime keys, checks that
  // the combined keyset contains both backends, both functionalities, both
  // original keys, and iterates over exactly the cross product of
  // {functionality1, functionality2} x {backend1, backend2}.

  // The "StartOf*Backends" values only mark where a block of keys begins;
  // they aren't real keys.
  const auto is_block_marker = [](DispatchKey key) {
    return key == DispatchKey::StartOfDenseBackends ||
        key == DispatchKey::StartOfSparseBackends ||
        key == DispatchKey::StartOfSparseCsrBackends ||
        key == DispatchKey::StartOfQuantizedBackends ||
        key == DispatchKey::StartOfNestedTensorBackends ||
        key == DispatchKey::StartOfAutogradFunctionalityBackends;
  };

  const auto first_raw =
      static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
  const auto last_raw =
      static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);

  for (uint8_t i = first_raw; i <= last_raw; i++) {
    for (uint8_t j = i + 1; j <= last_raw; j++) {
      ASSERT_LT(i, j);
      const auto tid1 = static_cast<DispatchKey>(i);
      const auto tid2 = static_cast<DispatchKey>(j);

      // Skip these because they aren't real keys.
      if (is_block_marker(tid1)) {
        continue;
      }
      if (is_block_marker(tid2)) {
        continue;
      }

      const auto backend1 = toBackendComponent(tid1);
      const auto backend2 = toBackendComponent(tid2);
      const auto functionality1 = toFunctionalityKey(tid1);
      const auto functionality2 = toFunctionalityKey(tid2);

      auto combined = DispatchKeySet({tid1, tid2});
      // The combined set has the backend bits
      ASSERT_TRUE(combined.has_backend(backend1));
      ASSERT_TRUE(combined.has_backend(backend2));
      // and it has the functionality bits
      ASSERT_TRUE(combined.has(functionality1));
      ASSERT_TRUE(combined.has(functionality2));
      // and it has the original two runtime keys
      ASSERT_TRUE(combined.has(tid1));
      ASSERT_TRUE(combined.has(tid2));

      // Collect every key the keyset iterates over into a real set.
      std::unordered_set<DispatchKey> visited_keys;
      for (auto iter = combined.begin(); *iter != *combined.end(); ++iter) {
        visited_keys.insert(*iter);
      }

      // The expected contents are the full {functionality} x {backend} cross
      // product; duplicates collapse when the two keys share an axis.
      std::unordered_set<DispatchKey> expected_keys;
      for (const auto functionality : {functionality1, functionality2}) {
        for (const auto backend : {backend1, backend2}) {
          expected_keys.insert(
              toRuntimePerBackendFunctionalityKey(functionality, backend));
        }
      }
      ASSERT_EQ(expected_keys, visited_keys);

      const bool shares_an_axis =
          backend1 == backend2 || functionality1 == functionality2;
      if (!shares_an_axis) {
        // since i and j are different keys, they should not have the same
        // functionality and backend
        ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
        // We have two runtime keys, that have different backends + per-backend
        // functionalities. So we should expect the full cross product of
        // runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU,
        // then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
        ASSERT_EQ(4, visited_keys.size());
      } else {
        // We have two runtime keys, with either the same backend or the same
        // per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
        // {AutogradCPU, AutogradCUDA} There should be 2 total runtime keys in
        // this set.
        ASSERT_EQ(2, visited_keys.size());
      }
    }
  }
}
292
TEST(DispatchKeySet, Full) {
  // The FULL keyset must report every functionality key in
  // [1, num_functionality_keys) as present, but not the
  // EndOfFunctionalityKeys sentinel.
  DispatchKeySet everything(DispatchKeySet::FULL);
  for (const auto raw : c10::irange(1, num_functionality_keys)) {
    const auto key = static_cast<DispatchKey>(raw);
    ASSERT_TRUE(everything.has(key));
  }
  ASSERT_FALSE(everything.has(DispatchKey::EndOfFunctionalityKeys));
}
301
TEST(DispatchKeySet, IteratorBasicOps) {
  // Exercises iterator construction, dereference, comparison and both
  // increment operators on empty, single-key and full keysets.
  DispatchKeySet no_keys;
  DispatchKeySet all_keys(DispatchKeySet::FULL);
  DispatchKeySet cpu_only = no_keys.add(DispatchKey::CPU);

  // Constructor + Comparison: an empty set's begin() and end() both
  // dereference to the EndOfFunctionalityKeys sentinel.
  ASSERT_EQ(*no_keys.begin(), DispatchKey::EndOfFunctionalityKeys);
  ASSERT_EQ(*no_keys.end(), DispatchKey::EndOfFunctionalityKeys);
  ASSERT_EQ(*cpu_only.begin(), DispatchKey::CPU);

  ASSERT_TRUE(no_keys.begin() == no_keys.end());
  ASSERT_TRUE(all_keys.begin() != all_keys.end());

  // Increment Ops: post-increment yields the pre-increment position,
  // pre-increment yields the advanced position.
  auto post_incremented = all_keys.begin();
  ASSERT_TRUE(all_keys.begin() == post_incremented++);
  auto pre_incremented = all_keys.begin();
  ASSERT_TRUE(all_keys.begin() != ++pre_incremented);
}
319
TEST(DispatchKeySet, getHighestPriorityBackendTypeId) {
  // Checks which key c10::highestPriorityBackendTypeId picks out of keysets
  // that mix backend keys with non-backend keys.

  // AutogradCPU isn't a backend key so it is ignored
  auto cpu_with_autograd =
      DispatchKeySet({DispatchKey::AutogradCPU, DispatchKey::CPU});
  ASSERT_EQ(
      DispatchKey::CPU, c10::highestPriorityBackendTypeId(cpu_with_autograd));

  // Functionalize isn't a backend key so it is ignored
  auto sparse_cuda_with_functionalize =
      DispatchKeySet({DispatchKey::Functionalize, DispatchKey::SparseCUDA});
  ASSERT_EQ(
      DispatchKey::SparseCUDA,
      c10::highestPriorityBackendTypeId(sparse_cuda_with_functionalize));

  auto sparse_csr_cuda_with_functionalize =
      DispatchKeySet({DispatchKey::Functionalize, DispatchKey::SparseCsrCUDA});
  ASSERT_EQ(
      DispatchKey::SparseCsrCUDA,
      c10::highestPriorityBackendTypeId(sparse_csr_cuda_with_functionalize));

  // quantizedCUDA has higher priority than CUDA
  auto cuda_and_quantized_cuda =
      DispatchKeySet({DispatchKey::CUDA, DispatchKey::QuantizedCUDA});
  ASSERT_EQ(
      DispatchKey::QuantizedCUDA,
      c10::highestPriorityBackendTypeId(cuda_and_quantized_cuda));
}
344
TEST(DispatchKeySet, IteratorEmpty) {
  // Walking an empty keyset from begin() to end() must take zero steps.
  DispatchKeySet no_keys;
  uint8_t steps = 0;

  auto cursor = no_keys.begin();
  while (cursor != no_keys.end()) {
    steps++;
    ++cursor;
  }
  ASSERT_EQ(steps, 0);
}
354
TEST(DispatchKeySet, IteratorCrossProduct) {
  // The iterator should return all runtime keys in the set,
  // including the cross product of {backends} x {functionalities}
  auto ks =
      DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
      DispatchKeySet(
          {DispatchKey::Dense,
           DispatchKey::FPGA,
           DispatchKey::AutogradFunctionality});

  auto iter = ks.begin();
  // iterate through dense backends first.
  ASSERT_EQ(DispatchKey::CPU, *(iter++));
  ASSERT_EQ(DispatchKey::CUDA, *(iter++));
  // FPGA doesn't have a backend bit, so it isn't included in the cross product.
  ASSERT_EQ(DispatchKey::FPGA, *(iter++));
  // iterate through the autograd keys last.
  ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
  ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
  // Those five keys must be everything: without this check, any extra runtime
  // key produced by the iterator would go unnoticed.
  ASSERT_EQ(*iter, *ks.end());
}
375
TEST(DispatchKeySet, IteratorFull) {
  // Counts how many keys the FULL keyset iterates over and compares that
  // against the number of runtime entries.
  DispatchKeySet all_keys(DispatchKeySet::FULL);
  std::ptrdiff_t iterated = std::distance(all_keys.begin(), all_keys.end());

  // Total # of runtime entries includes an entry for DispatchKey::Undefined,
  // which is not included when iterating through the DispatchKeySet.
  const auto expected = std::ptrdiff_t{num_runtime_entries} - 1;
  ASSERT_EQ(iterated, expected);
}
TEST(DispatchKeySet, FailAtEndIterator) {
  // Constructing an iterator at position num_backends + num_functionality_keys
  // is accepted; one position further must raise c10::Error.
  DispatchKeySet all_keys(DispatchKeySet::FULL);
  uint64_t bits = all_keys.raw_repr();
  const auto end_position = num_backends + num_functionality_keys;

  // doesn't throw
  DispatchKeySet::iterator(&bits, end_position);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(
      DispatchKeySet::iterator(&bits, end_position + 1), c10::Error);
}
396
TEST(DispatchKeySet, TestBackendComponentToString) {
  // Every BackendComponent value from 0 through EndOfBackendKeys must print
  // as something other than the unknown-bit placeholder, and no two values
  // may print the same.
  std::unordered_set<std::string> names_seen;
  const auto last_raw = static_cast<int64_t>(BackendComponent::EndOfBackendKeys);
  for (int64_t raw = 0; raw <= last_raw; ++raw) {
    const auto component = static_cast<BackendComponent>(raw);
    const std::string name(toString(component));
    ASSERT_FALSE(name == "UNKNOWN_BACKEND_BIT");
    ASSERT_FALSE(names_seen.count(name) > 0);
    names_seen.insert(name);
  }
}
409
TEST(DispatchKeySet, TestEndOfRuntimeBackendKeysAccurate) {
  // C10_FORALL_FUNCTIONALITY_KEYS expands the macro below once per entry, each
  // expansion overwriting the variable, so after the expansion it holds the
  // EndOf<X>Backends key of the final entry. That key must be the one
  // EndOfRuntimeBackendKeys refers to.
  DispatchKey last_end_key = DispatchKey::Undefined;
#define RECORD_END_KEY(fullname, prefix) \
  last_end_key = DispatchKey::EndOf##fullname##Backends;
  C10_FORALL_FUNCTIONALITY_KEYS(RECORD_END_KEY)
#undef RECORD_END_KEY
  ASSERT_TRUE(last_end_key == DispatchKey::EndOfRuntimeBackendKeys);
}
417
TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
  // Every DispatchKey from 0 through EndOfAliasKeys (minus the synthetic
  // marker keys) must print as a name that does not contain "Unknown", and no
  // two keys may print the same.

  // These synthetic keys never actually get used and don't need
  // to be printed
  const auto is_synthetic = [](DispatchKey key) {
    return key == DispatchKey::EndOfFunctionalityKeys ||
        key == DispatchKey::StartOfDenseBackends ||
        key == DispatchKey::StartOfQuantizedBackends ||
        key == DispatchKey::StartOfSparseBackends ||
        key == DispatchKey::StartOfSparseCsrBackends ||
        key == DispatchKey::StartOfNestedTensorBackends ||
        key == DispatchKey::StartOfAutogradFunctionalityBackends;
  };

  std::unordered_set<std::string> names_seen;
  const int last_raw = static_cast<int>(DispatchKey::EndOfAliasKeys);
  for (int i = 0; i <= last_raw; i++) {
    const auto key = static_cast<DispatchKey>(i);
    if (is_synthetic(key)) {
      continue;
    }
    const std::string name(toString(key));
    // On failure, report the raw value and the name of the preceding key so
    // the offending enum entry is easy to locate.
    ASSERT_TRUE(name.find("Unknown") == std::string::npos)
        << i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
        << ")";
    ASSERT_TRUE(names_seen.count(name) == 0);
    names_seen.insert(name);
  }
}
440