xref: /aosp_15_r20/external/pytorch/c10/core/DispatchKeySet.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/llvmMathExtras.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <type_traits>
17 
18 namespace c10 {
19 
// One entry of the per-functionality lookup table returned by
// offsetsAndMasks(). DispatchKeySet::getDispatchTableIndexForDispatchKeySet()
// adds `offset` to a backend index, and ANDs `mask` with the keyset's raw bits
// to select the backend bits that are relevant for that functionality.
struct FunctionalityOffsetAndMask {
  // empty constructor shouldn't be used; only needed to initialize
  // the array before populating it.
  FunctionalityOffsetAndMask() = default;
  constexpr FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
      : offset(offset), mask(mask) {}
  // This needs to be big enough to cover the size of the operator table.
  uint16_t offset{};
  // See Note [No More Than 16 Backends]
  // This mask needs to be big enough to mask all of the backend bits.
  // We probably don't ever want to have more than 16 backend bits, so uint16_t
  // should be enough.
  uint16_t mask{};
};
34 static_assert(
35     c10::num_runtime_entries < 65536,
36     "The dispatcher currently only supports up to 2^16 runtime entries");
37 
38 C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
39 initializeFunctionalityOffsetsAndMasks();
40 
41 C10_ALWAYS_INLINE static const std::
42     array<FunctionalityOffsetAndMask, num_functionality_keys>&
offsetsAndMasks()43     offsetsAndMasks() {
44   static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
45   return offsets_and_masks_;
46 }
47 
// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, OR'ing them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are cuda kernels
// that take in a scalar cpu tensor).
59 
60 // A representation of a set of DispatchKeys.  A tensor may have multiple
61 // tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
62 // DispatchKeySet specifies what type ids apply.  The internal representation is
63 // as a 64-bit bit set (this means only 64 tensor type ids are supported).
64 //
65 // As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
66 // "what is the highest priority DispatchKey in the set"?  (The set itself is
67 // not ordered; two sets with the same ids will always have the ids ordered in
68 // the same way.)
69 //
70 // Note [DispatchKeySet Internal Representation]
71 // Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
72 // that get passed around at runtime.
73 // However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
74 // and individual dispatch keys.
75 //
76 // First: why do we have this distinction, and why not map every dispatch key
77 // directly to a bit? This is mostly because we have several types of
78 // functionalities that different backends would like to customize. For example,
79 // we have:
80 // - "Dense":     CPU, CUDA, XLA, ... (~12 keys)
81 // - "Sparse":    SparseCPU, SparseCUDA, ...
82 // - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
83 // - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
84 // - "Autograd":  AutogradCPU, AutogradCUDA, Autograd XLA, ...
85 // The problem is that total number of keys grows quadratically with [#
86 // backends] x [# functionalities], making it very difficult to map each key
87 // directly to a bit in a bitset without dramatically increasing the size of the
88 // bitset over time.
89 //
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 3 categories.
//
// (1) "Building block" keys
//    (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//    CUDABit)
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//    (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
97 // (2) "Runtime" keys
98 //    (a) "non-customizable backends" (e.g. FPGA)
99 //    (b) "non-customizable functionalities" (e.g. Functionalize)
100 //    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
101 //    SparseCPU, AutogradCPU)
102 // (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
103 //
104 // (1) Building block keys always correspond to individual bits in a
105 // DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
106 // runtime keys. e.g.
107 //     auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit,
108 //     DispatchKey::Dense});
109 //     // The keyset has the runtime dense-cpu key.
110 //     dense_cpu_ks.has(DispatchKey::CPU);
111 //     // And it contains the building block keys too.
112 //     dense_cpu_ks.has(DispatchKey::CPUBit);
113 //     dense_cpu_ks.has(DispatchKey::Dense);
114 //
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building
// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
// building blocks (Dense, Sparse, SparseCsr, Quantized,
// AutogradFunctionality, ...). These keys together allow every
// dispatcher operator to be customized in up to 12*5 different
// ways. Each of those requires a slot in the operator table of every
// dispatcher operator.  Not every piece of functionality necessarily
// needs to be customizable per-backend, and not every backend
// necessarily needs to be able to customize every type of
// functionality.
130 //
131 //
132 // (2) Every runtime key corresponds directly to a slot in an operator's runtime
133 // dispatch table, and you can directly register kernels to a runtime dispatch
134 // key.
135 //
136 // For per-backend functionalities like "Dense" or "AutogradFunctionality",
137 // you can think of the corresponding runtime dispatch keys as "instances" of
138 // that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
139 // runtime instances of the "Dense" building block key.
140 
141 // (2a) and (2b) are represented identically in the DispatchKeySet logic:
142 // - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
143 // customizable per backend.
144 //   In order to do so, we'd need to promote it to a per-backend functionality
145 //   "building block" key.
146 // - non-customizable backends (e.g. FPGA) can NOT customize existing
147 // functionality like Sparse, Autograd, etc.
148 //   In order to do so, we'd need to promote it to a backend "building block"
149 //   key.
150 //
151 // In both cases, these keys directly correspond to runtime slots in the
152 // operator table.
153 //
154 //
155 // (3) "Alias" keys
156 // See Note [Alias Dispatch Keys]
157 //
158 // Final note: for anyone making future changes to the Dispatcher +
159 // DispatchKeySet internals, there's a closed PR with a basic
160 // python-implementation of the Dispatcher that might be useful in quickly
161 // testing out and validating changes. See it at
162 // https://github.com/pytorch/pytorch/pull/68743
163 
164 // An undefined tensor is one with an empty tensor type set.
165 class DispatchKeySet final {
166  public:
167   enum Full { FULL };
168   enum FullAfter { FULL_AFTER };
169   enum Raw { RAW };
170 
171   // NB: default constructor representation as zero is MANDATORY as
172   // use of DispatchKeySet in TLS requires this.
173   constexpr DispatchKeySet() = default;
174 
DispatchKeySet(Full)175   constexpr DispatchKeySet(Full)
176       : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
177 
DispatchKeySet(FullAfter,DispatchKey t)178   constexpr DispatchKeySet(FullAfter, DispatchKey t)
179       // LSB after t are OK, but not t itself.
180       // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
181       // Quantized > Dense). But backends don't really have an ordering.
182       // Therefore, we're enforcing that FullAfter can only be used on
183       // "functionality" keys.
184       : repr_(
185             (1ULL
186              << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
187                  1)) -
188             1) {
189     *this = add(DispatchKey::PythonDispatcher);
190   }
191 
192   // Public version of DispatchKeySet(uint64_t) API; external users
193   // must be explicit when they do this!
DispatchKeySet(Raw,uint64_t x)194   constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
195 
DispatchKeySet(BackendComponent k)196   constexpr explicit DispatchKeySet(BackendComponent k) {
197     if (k == BackendComponent::InvalidBit) {
198       repr_ = 0;
199     } else {
200       repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
201     }
202   }
203 
DispatchKeySet(DispatchKey k)204   constexpr explicit DispatchKeySet(DispatchKey k) {
205     // NOLINTNEXTLINE(bugprone-branch-clone)
206     if (k == DispatchKey::Undefined) {
207       // Case 1: handle Undefined specifically
208       repr_ = 0;
209     } else if (k <= DispatchKey::EndOfFunctionalityKeys) {
210       // Case 2: handle "functionality-only" keys
211       // These keys have a functionality bit set, but no backend bits
212       // These can technically be either:
213       // - valid runtime keys (e.g. DispatchKey::AutogradOther,
214       // DispatchKey::FuncTorchBatched, etc)
215       // - "building block" keys that aren't actual runtime keys (e.g.
216       // DispatchKey::Dense or Sparse)
217       uint64_t functionality_val = 1ULL
218           << (num_backends + static_cast<uint8_t>(k) - 1);
219       repr_ = functionality_val;
220     } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
221       // Case 3: "runtime" keys that have a functionality bit AND a backend bit.
222       // First compute which bit to flip for the functionality.
223       auto functionality_k = toFunctionalityKey(k);
224       // The - 1 is because Undefined is technically a "functionality" that
225       // doesn't show up in the bitset. So e.g. Dense is technically the second
226       // functionality, but the lowest functionality bit.
227       uint64_t functionality_val = 1ULL
228           << (num_backends + static_cast<uint8_t>(functionality_k) - 1);
229 
230       // then compute which bit to flip for the backend
231       // Case 4a: handle the runtime instances of "per-backend functionality"
232       // keys For example, given DispatchKey::CPU, we should set:
233       // - the Dense functionality bit
234       // - the CPUBit backend bit
235       // first compute which bit to flip for the backend
236       auto backend_k = toBackendComponent(k);
237       uint64_t backend_val = backend_k == BackendComponent::InvalidBit
238           ? 0
239           : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
240       repr_ = functionality_val + backend_val;
241     } else {
242       // At this point, we should have covered every case except for alias keys.
243       // Technically it would be possible to add alias dispatch keys to a
244       // DispatchKeySet, but the semantics are a little confusing and this
245       // currently isn't needed anywhere.
246       repr_ = 0;
247     }
248   }
249 
keys_to_repr(std::initializer_list<DispatchKey> ks)250   constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
251     uint64_t repr = 0;
252     for (auto k : ks) {
253       repr |= DispatchKeySet(k).repr_;
254     }
255     return repr;
256   }
257 
backend_bits_to_repr(std::initializer_list<BackendComponent> ks)258   constexpr uint64_t backend_bits_to_repr(
259       std::initializer_list<BackendComponent> ks) {
260     uint64_t repr = 0;
261     for (auto k : ks) {
262       repr |= DispatchKeySet(k).repr_;
263     }
264     return repr;
265   }
266 
DispatchKeySet(std::initializer_list<DispatchKey> ks)267   explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
268       : repr_(keys_to_repr(ks)) {}
269 
DispatchKeySet(std::initializer_list<BackendComponent> ks)270   explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
271       // Note: for some reason, putting this logic directly in the constructor
272       // appears to fail to compile on CUDA 10.1.
273       // See an example internal failure at
274       // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
275       : repr_(backend_bits_to_repr(ks)) {}
276 
277   // Test if a DispatchKey is in the set
has(DispatchKey t)278   inline bool has(DispatchKey t) const {
279     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
280     return has_all(DispatchKeySet(t));
281   }
has_backend(BackendComponent t)282   constexpr bool has_backend(BackendComponent t) const {
283     return has_all(DispatchKeySet(t));
284   }
285 
286   // Test if a DispatchKey is in the set
287   // Given a DispatchKeySet of functionality keys and (potentially) backend
288   // keys, tests if all of them are in the current set.
has_all(DispatchKeySet ks)289   constexpr bool has_all(DispatchKeySet ks) const {
290     return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
291   }
292 
293   // Given a DispatchKeySet of functionality keys and (potentially) backend
294   // keys, tests if any of them are in the current set. This could technically
295   // be pretty easily implemented using has(). It is strictly a perf
296   // optimization though. There are many places in the code base where we want
297   // to test for multiple functionality keys together. HOWEVER, runtime
298   // per-backend functionality keys aren't allowed to be used with this
299   // function, because you can end up with weird results. e.g.
300   // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
301   // would return true.
has_any(DispatchKeySet ks)302   inline bool has_any(DispatchKeySet ks) const {
303     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
304         // Either there are no backend bits in the input keyset
305         ((ks.repr_ & full_backend_mask) == 0) ||
306         // or there are no per-backend-functionality bits
307         // See [Note: Per-Backend Functionality Dispatch Keys]
308         ((ks &
309           DispatchKeySet({
310                              DispatchKey::Dense,
311                              DispatchKey::Quantized,
312                              DispatchKey::Sparse,
313                              DispatchKey::SparseCsr,
314                              DispatchKey::AutogradFunctionality,
315                          })
316               .repr_) == 0));
317     return static_cast<bool>((repr_ & ks.repr_) != 0);
318   }
319   // Test if DispatchKeySet is a superset of ks.
isSupersetOf(DispatchKeySet ks)320   bool isSupersetOf(DispatchKeySet ks) const {
321     return (repr_ & ks.repr_) == ks.repr_;
322   }
323   // Perform set union
324   constexpr DispatchKeySet operator|(DispatchKeySet other) const {
325     return DispatchKeySet(repr_ | other.repr_);
326   }
327   // Perform set intersection
328   constexpr DispatchKeySet operator&(DispatchKeySet other) const {
329     return DispatchKeySet(repr_ & other.repr_);
330   }
331   // Compute the set difference self - other,
332   // but ONLY for the functionality keys.
333   // Any backend bits set on self will remain unchanged.
334   // See Note [Removing keys from DispatchKeySet Only Affects Functionality
335   // Keys]
336   constexpr DispatchKeySet operator-(DispatchKeySet other) const {
337     return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
338   }
339 
340   // Compute self ^ other
341   constexpr DispatchKeySet operator^(DispatchKeySet other) const {
342     return DispatchKeySet(repr_ ^ other.repr_);
343   }
344   bool operator==(DispatchKeySet other) const {
345     return repr_ == other.repr_;
346   }
347   bool operator!=(DispatchKeySet other) const {
348     return repr_ != other.repr_;
349   }
350   // Add a DispatchKey to the DispatchKey set.  Does NOT mutate,
351   // returns the extended DispatchKeySet!
add(DispatchKey t)352   C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const {
353     return *this | DispatchKeySet(t);
354   }
add(DispatchKeySet ks)355   C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const {
356     return *this | ks;
357   }
358 
359   // Remove a DispatchKey from the DispatchKey set.
360   // This is generally not an operation you should be doing
361   // (it's used to implement the printing overload, operator<<)
362   //
363   // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
364   // Only functionality bits are allowed to be removed from a keyset.
365   // For now, we're only allowing removal of "functionality bits" from the
366   // keyset, which is specifically needed by the fallthrough key calculation
367   // logic. Why is removing backend bits problematic? Consider this example:
368   //
369   // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
370   // DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
371   // DispatchKeySet([DispatchKey.CPU,
372   // DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
373   //
374   // What do we want to happen?
375   // Technically, we'd like it to be true that after removal,
376   // the first keyset still has the CUDA dispatch key while the second doesn't.
377   // Unfortunately there's no way to represent that, because the two keysets are
378   // represented the same way internally: functionality bits: Autograd, Dense
379   // backend bits: CPU, CUDA
380   //
381   // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
382   // bit from the bitset.
remove(DispatchKey t)383   C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
384     return DispatchKeySet(
385         repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
386   }
387   // You're allowed to remove a backend bit from a DispatchKeySet,
388   // but you have to be explicit about it (remove_backend() instead of
389   // remove()).
remove_backend(BackendComponent b)390   constexpr DispatchKeySet remove_backend(BackendComponent b) const {
391     return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
392   }
393   // Is the set empty?  (AKA undefined tensor)
empty()394   bool empty() const {
395     return repr_ == 0;
396   }
raw_repr()397   uint64_t raw_repr() {
398     return repr_;
399   }
400 
highestFunctionalityKey()401   DispatchKey highestFunctionalityKey() const {
402     auto functionality_idx = indexOfHighestBit();
403     // This means that none of the functionality bits were set.
404     if (functionality_idx < num_backends)
405       return DispatchKey::Undefined;
406     // The first num_backend bits in the keyset don't correspond to real
407     // dispatch keys.
408     return static_cast<DispatchKey>(functionality_idx - num_backends);
409   }
410 
411   // This is similar like toBackendComponent(DispatchKey), but less restrictive.
412   // toBackendComponent() errors out if the key that it was passed has no
413   // backend bits, which is useful for error checking. We need a version of that
414   // here that can also handle "fake" backends like FPGA, because they need to
415   // map to the AutogradOther key. For those backends, we return
416   // BackendComponent::InvalidBit.
highestBackendKey()417   BackendComponent highestBackendKey() const {
418     // mask to mask out functionality bits
419     auto backend_idx =
420         DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
421     // all zeros across the backend bits means that no backend bits are set.
422     if (backend_idx == 0)
423       return BackendComponent::InvalidBit;
424     return static_cast<BackendComponent>(backend_idx);
425   }
426 
427   // returns the DispatchKey of highest priority in the set.
highestPriorityTypeId()428   DispatchKey highestPriorityTypeId() const {
429     auto functionality_k = highestFunctionalityKey();
430     if (isPerBackendFunctionalityKey(functionality_k)) {
431       return toRuntimePerBackendFunctionalityKey(
432           functionality_k, highestBackendKey());
433     }
434     return functionality_k;
435   }
436 
437   // Returns the index of the most-significant bit in the keyset.
438   // This is used to as part of the calculation into the operator table to get:
439   // - the highest "functionality" bit in the keyset.
440   // - the highest "backend" bit in the keyset.
indexOfHighestBit()441   uint8_t indexOfHighestBit() const {
442     return 64 - llvm::countLeadingZeros(repr_);
443   }
444 
445 #if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
446   // [Note: Trimmed Mobile Dispatch Keys]
447   /**
448    * The method below maps the dispatch key in the enum DispatchKey to an
449    * integer index in the dispatchTable_ array in OperatorEntry. The array
450    * is trimmed for mobile to reduce peak memory usage since it's
451    * unnecessary to reserve additional space for dispatch keys that will
452    * never be used on mobile.
453    */
getDispatchTableIndexForDispatchKeySet()454   int getDispatchTableIndexForDispatchKeySet() const {
455     auto dk = highestPriorityTypeId();
456     switch (dk) {
457       case DispatchKey::Undefined:
458         return 0;
459       case DispatchKey::CPU:
460         return 1;
461       case DispatchKey::QuantizedCPU:
462         return 2;
463       case DispatchKey::SparseCPU:
464         return 3;
465       case DispatchKey::BackendSelect:
466         return 4;
467       case DispatchKey::ADInplaceOrView:
468         return 5;
469       case DispatchKey::AutogradOther:
470         return 6;
471       case DispatchKey::AutogradCPU:
472         return 7;
473       default:
474         return -1;
475     }
476   }
477 #else
478   // returns the index in the operator table of highest priority key in the the
479   // keyset Note that we could in theory implement this using
480   // highestPriorityTypeId(), but this code is very hotpath and we can do it
481   // faster without it.
getDispatchTableIndexForDispatchKeySet()482   int getDispatchTableIndexForDispatchKeySet() const {
483     auto functionality_idx =
484         DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
485     auto offset_and_mask = offsetsAndMasks()[functionality_idx];
486     // Mask the functionality bits out first, then right-shift by 1.
487     // right-shifting by 1 because everything is zero-indexed.
488     // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
489     // give us an offset of 1, etc.
490     auto backend_idx =
491         DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
492     return offset_and_mask.offset + backend_idx;
493   }
494 #endif
495 
496   // returns the "index" of the highest priority backend in the keyset.
497   // This is pretty similar to getBackendKey(), but:
498   // - It's hotpath code (part of the runtime bitset calculation)
499   // - I's returns an integer index, not an enum value
500   // - Everything is shifted to the right by 1.
501   //   BackendComponent::InvalidBit is technically the lowest enum value,
502   //   but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
503   //   etc.
getBackendIndex()504   uint64_t getBackendIndex() const {
505     return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
506   }
507 
508  private:
DispatchKeySet(uint64_t repr)509   constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
510   uint64_t repr_ = 0;
511 
512  public:
513   // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
514   // in the set. The iterator is only invalidated by the destruction of the
515   // underlying DispatchKeySet as the iterator stores a pointer to the raw
516   // representation of the DispatchKeySet. Note: When we encounter a per-backend
517   // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
518   // in the keyset, for that functionality. For example, if the next
519   // functionality key to iterate over is Autograd, and the backend bits in the
520   // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
521   // then the next two keys we return will be DispatchKey::AutogradCPU,
522   // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
523   // CUDA in DispatchKey.h).
524   class iterator {
525    public:
526     using self_type = iterator;
527     using iterator_category = std::input_iterator_tag;
528     using value_type = DispatchKey;
529     using difference_type = ptrdiff_t;
530     using reference = value_type&;
531     using pointer = value_type*;
532     // final mask value should mask out the entire keyset
533     static const uint8_t end_iter_mask_val =
534         num_backends + num_functionality_keys;
535     // final key value should be the last DispatchKey
536     static const uint8_t end_iter_key_val = num_functionality_keys;
537 
538     // current_dispatchkey_idx_ will iterate through all functionality bits.
539     // current_backendcomponent_idx_ will iterate through all backend bits.
540     explicit iterator(
541         const uint64_t* data_ptr,
542         uint8_t next_functionality = num_backends,
543         uint8_t next_backend = 0)
data_ptr_(data_ptr)544         : data_ptr_(data_ptr),
545           next_functionality_(next_functionality),
546           next_backend_(next_backend),
547           // These are in an invalid state at construction time, and set by the
548           // first increment call
549           current_dispatchkey_idx_(end_iter_key_val),
550           current_backendcomponent_idx_(end_iter_key_val) {
551       // Go to the first key in the set
552       TORCH_INTERNAL_ASSERT(
553           next_functionality_ >= num_backends,
554           "num_backends=",
555           static_cast<uint32_t>(num_backends),
556           "next_functionality_=",
557           static_cast<uint32_t>(next_functionality_));
558       ++(*this);
559     }
560 
561     C10_API self_type& operator++();
562 
563     self_type operator++(int) {
564       self_type previous_iterator = *this;
565       ++(*this);
566       return previous_iterator;
567     }
568 
569     bool operator==(const self_type& rhs) const {
570       return next_functionality_ == rhs.next_functionality_ &&
571           current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
572           next_backend_ == rhs.next_backend_ &&
573           current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
574     }
575     bool operator!=(const self_type& rhs) const {
576       return next_functionality_ != rhs.next_functionality_ ||
577           current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
578           next_backend_ != rhs.next_backend_ ||
579           current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
580     }
581     DispatchKey operator*() const {
582       auto functionality_key =
583           static_cast<DispatchKey>(current_dispatchkey_idx_);
584       if (isPerBackendFunctionalityKey(functionality_key)) {
585         auto next_key = toRuntimePerBackendFunctionalityKey(
586             functionality_key,
587             static_cast<BackendComponent>(current_backendcomponent_idx_));
588         // We expect all of the Dense, Sparse, Quantized, and Autograd keys to
589         // be ordered the same way with respect to their backends
590         TORCH_INTERNAL_ASSERT(
591             toBackendComponent(next_key) ==
592                 static_cast<BackendComponent>(current_backendcomponent_idx_),
593             "Tried to map functionality key ",
594             toString(functionality_key),
595             " and backend bit ",
596             toString(
597                 static_cast<BackendComponent>(current_backendcomponent_idx_)),
598             " to a runtime key, but ended up with ",
599             toString(next_key),
600             ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
601             " Please double check that enum for inconsistencies.");
602         return next_key;
603       } else {
604         return functionality_key;
605       }
606     }
607 
608    private:
609     const uint64_t* data_ptr_;
610     uint8_t next_functionality_;
611     uint8_t next_backend_;
612     uint8_t current_dispatchkey_idx_;
613     uint8_t current_backendcomponent_idx_;
614   };
615 
616  public:
617   // Returns iterator to the first key in the set. If no keys are in the
618   // set, then will return the end iterator.
begin()619   iterator begin() const {
620     return iterator(&repr_);
621   }
622 
623   // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
624   // this as the end iterator.
end()625   iterator end() const {
626     return iterator(&repr_, iterator::end_iter_mask_val);
627   }
628 };
629 
630 C10_API std::string toString(DispatchKeySet);
631 C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);
632 
getDispatchTableIndexForDispatchKey(DispatchKey k)633 C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
634   return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
635 }
636 
637 // Alias key DispatchKey::Autograd maps to
638 // (autograd_dispatch_keyset x full_backend_mask)
639 // NB: keys in this set also get associated with CompositeImplicitAutograd
640 //
641 // Note [autograd_dispatch_keyset Does Not Include Backend Bits]
642 // We don't want to include any backend bits (BackendComponent::CPUBit, etc)
643 // directly in autograd_dispatch_keyset.
644 // Why? keysets like autograd_dispatch_keyset are commonly used to remove
645 // autograd keys from a DispatchKeySet throughout the code base. However, you
646 // are only allowed to remove functionality bits from a keyset, not backend
647 // bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
648 // Keys] for details. To be consistent and avoid confusion, we're explicitly
649 // setting up autograd_dispatch_keyset to not have any backend bits.
650 constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
651     DispatchKey::AutogradFunctionality,
652     DispatchKey::AutogradOther,
653     DispatchKey::AutogradNestedTensor,
654 });
655 
656 constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
657     DispatchKey::AutocastCPU,
658     DispatchKey::AutocastMPS,
659     DispatchKey::AutocastCUDA,
660     DispatchKey::AutocastXPU,
661     DispatchKey::AutocastIPU,
662     DispatchKey::AutocastHPU,
663     DispatchKey::AutocastXLA,
664     DispatchKey::AutocastPrivateUse1,
665 });
666 
667 // See Note [TLS Initialization]
668 constexpr DispatchKeySet default_included_set = DispatchKeySet({
669     DispatchKey::BackendSelect,
670     DispatchKey::ADInplaceOrView,
671 });
672 
673 constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
674     DispatchKey::AutocastCPU,
675     DispatchKey::AutocastMPS,
676     DispatchKey::AutocastCUDA,
677     DispatchKey::AutocastXPU,
678     DispatchKey::AutocastIPU,
679     DispatchKey::AutocastHPU,
680     DispatchKey::AutocastXLA,
681     DispatchKey::AutocastPrivateUse1,
682 });
683 
684 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
685     autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);
686 
687 constexpr DispatchKeySet python_ks = DispatchKeySet({
688     DispatchKey::Python,
689     DispatchKey::PythonTLSSnapshot,
690 });
691 
692 constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
693 
694 constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);
695 
696 constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
697 
698 // backend dispatch keys that map to DispatchKey::AutogradOther
699 // NB: keys in this set also get associated with CompositeImplicitAutograd
700 constexpr DispatchKeySet autogradother_backends =
701     DispatchKeySet(
702         // HIP and VE aren't in this list: they now have their own backend bits
703         // which means that they can now have their own Autograd keys.
704         // Technically, HIP will now redispatch to its own custom AutogradHIP
705         // slot in the runtime table.
706         {DispatchKey::FPGA,
707          DispatchKey::MAIA,
708          DispatchKey::Vulkan,
709          DispatchKey::Metal,
710          DispatchKey::CustomRNGKeyId,
711          DispatchKey::MkldnnCPU,
712          // Sparse and Quantized backends also live here.
713          DispatchKey::Sparse,
714          DispatchKey::SparseCsr,
715          DispatchKey::Quantized})
716     // Including the backend bits because this keyset is used during op
717     // registration, which requires looping over all runtime autogradother
718     // backend keys.
719     | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
720 
721 // The set of dispatch keys that come after autograd
722 // n.b. this relies on the fact that AutogradOther is currently the lowest
723 // Autograd key
724 constexpr DispatchKeySet after_autograd_keyset =
725     DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
726 
727 // The set of dispatch keys that come after ADInplaceOrView
728 constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
729     DispatchKeySet::FULL_AFTER,
730     c10::DispatchKey::ADInplaceOrView);
731 
732 // The set of dispatch keys that come after Functionalize
733 constexpr DispatchKeySet after_func_keyset =
734     DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
735         .remove(
736             // NOTE: we also need to remove ADInplaceOrView from the keyset when
737             // redispatching after the func kernels. This is because we're not
738             // calling the same op; we originally called an inplace op, and now
739             // we aren't. The original key calculation figured out which keys
740             // were Fallthrough based on the inplace op. That means that it did
741             // not include the ADInPlaceOrView kernel as a fallthrough key.
742             // However, we WANT the ADInPlaceOrView kernel to be ignored now
743             // that we're calling an out-of-place op. Re-invoking
744             // Dispatcher::call would re-run the Fallthrough key calculation and
745             // get us that, But at::redispatch is more performant. We can get
746             // away with it by explicitly removing the key here.
747             c10::DispatchKey::ADInplaceOrView);
748 
749 constexpr DispatchKeySet backend_bitset_mask =
750     DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);
751 
752 constexpr auto inplace_or_view_ks =
753     DispatchKeySet(DispatchKey::ADInplaceOrView);
754 constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
755 constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
756 constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
757 constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
758 constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
759 constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
760 constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
761 constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
762 constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
763 constexpr auto autograd_privateuse1_ks =
764     DispatchKeySet(DispatchKey::AutogradPrivateUse1);
765 constexpr auto autograd_privateuse2_ks =
766     DispatchKeySet(DispatchKey::AutogradPrivateUse2);
767 constexpr auto autograd_privateuse3_ks =
768     DispatchKeySet(DispatchKey::AutogradPrivateUse3);
769 constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
770 constexpr auto autograd_nested =
771     DispatchKeySet(DispatchKey::AutogradNestedTensor);
772 // keyset corresponding to functorch keys that have their own dedicated
773 // TensorImpl subclass.
774 constexpr auto functorch_transforms_ks = DispatchKeySet(
775     {DispatchKey::FuncTorchBatched,
776      DispatchKey::FuncTorchVmapMode,
777      DispatchKey::Batched,
778      DispatchKey::VmapMode,
779      DispatchKey::FuncTorchGradWrapper});
780 
781 constexpr auto functorch_batched_ks =
782     DispatchKeySet({DispatchKey::FuncTorchBatched});
783 
784 // This keyset has:
785 // (1) the functionality bits corresponding to backends (dense, sparse,
786 // quantized) (2) all of the backend bits set
787 constexpr DispatchKeySet backend_functionality_keys =
788     DispatchKeySet({
789         DispatchKey::Dense,
790         DispatchKey::Quantized,
791         DispatchKey::Sparse,
792         DispatchKey::SparseCsr,
793     }) |
794     DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
795 
// Pairs a 16-bit offset with a 16-bit mask over the backend bits.
//
// Both members are value-initialized by default so that a default-constructed
// instance never carries indeterminate data; this mirrors
// FunctionalityOffsetAndMask. The type remains an aggregate, so
// `OpTableOffsetAndMask{offset, mask}` keeps working.
struct OpTableOffsetAndMask {
  // Offset component; zero unless set explicitly.
  uint16_t offset{};
  // Mask over backend bits. Being a uint16_t, it can cover at most 16 of them.
  uint16_t backend_mask{};
};
800 
// Compile-time guard: OpTableOffsetAndMask::backend_mask is a uint16_t, so it
// has room for at most 16 backend bits. Fails the build if num_backends grows
// past that.
static_assert(
    num_backends <= 16,
    "Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
    " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
805 
// Returns true if `t` is a backend dispatch key, false otherwise.
C10_API bool isBackendDispatchKey(DispatchKey t);
808 
// Resolves an alias dispatch key `t` to the DispatchKeySet it stands for, if
// applicable.
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);
811 
// Resolves an alias dispatch key `t` to a DispatchKeySet if applicable, and
// checks whether `k` is a part of that set.
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);
815 
// Returns a DispatchKeySet of all backend keys mapped to the Autograd dispatch
// key `t`. The returned DispatchKeySet is empty if `t` is not an alias of
// DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
819 
820 // Returns a DispatchKeySet of autograd related keys mapped to backend.
821 // for a given backend key, use the associated autograd key.
822 // for non-backend keys, use AutogradOther as a default.
823 // Note: it's convenient and fast to return a default here rather than (say)
824 // returning an std::optional<DispatchKey>, or throwing. But it makes callers
825 // responsible for either a) enforcing the invariant that only backend keys
826 // be passed as arguments, or b) interpreting our return value carefully.
getAutogradRelatedKeySetFromBackend(BackendComponent t)827 inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
828   switch (t) {
829     case BackendComponent::CPUBit:
830       return inplace_or_view_ks | autograd_cpu_ks;
831     case BackendComponent::IPUBit:
832       return inplace_or_view_ks | autograd_ipu_ks;
833     case BackendComponent::XPUBit:
834       return inplace_or_view_ks | autograd_xpu_ks;
835     case BackendComponent::CUDABit:
836       return inplace_or_view_ks | autograd_cuda_ks;
837     case BackendComponent::XLABit:
838       return inplace_or_view_ks | autograd_xla_ks;
839     case BackendComponent::LazyBit:
840       return inplace_or_view_ks | autograd_lazy_ks;
841     case BackendComponent::MetaBit:
842       return inplace_or_view_ks | autograd_meta_ks;
843     case BackendComponent::MPSBit:
844       return inplace_or_view_ks | autograd_mps_ks;
845     case BackendComponent::HPUBit:
846       return inplace_or_view_ks | autograd_hpu_ks;
847     case BackendComponent::PrivateUse1Bit:
848       return inplace_or_view_ks | autograd_privateuse1_ks;
849     case BackendComponent::PrivateUse2Bit:
850       return inplace_or_view_ks | autograd_privateuse2_ks;
851     case BackendComponent::PrivateUse3Bit:
852       return inplace_or_view_ks | autograd_privateuse3_ks;
853     default:
854       return inplace_or_view_ks | autograd_other_ks;
855   }
856 }
857 
858 // Returns a DispatchKeySet of autocast related keys mapped to backend.
getAutocastRelatedKeySetFromBackend(BackendComponent t)859 inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
860   constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
861   constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
862   constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
863   constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
864   constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
865   constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
866   constexpr auto autocast_privateuse1_ks =
867       DispatchKeySet(DispatchKey::AutocastPrivateUse1);
868   constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
869   switch (t) {
870     case BackendComponent::CPUBit:
871       return autocast_cpu_ks;
872     case BackendComponent::XPUBit:
873       return autocast_xpu_ks;
874     case BackendComponent::IPUBit:
875       return autocast_ipu_ks;
876     case BackendComponent::HPUBit:
877       return autocast_hpu_ks;
878     case BackendComponent::CUDABit:
879       return autocast_cuda_ks;
880     case BackendComponent::XLABit:
881       return autocast_xla_ks;
882     case BackendComponent::PrivateUse1Bit:
883       return autocast_privateuse1_ks;
884     case BackendComponent::MPSBit:
885       return autocast_mps_ks;
886     default:
887       return DispatchKeySet();
888   }
889 }
890 
891 // returns the "backend" DispatchKey of highest priority in the set.
892 // This is basically like highestBackendKey(), except that we have some
893 // "functionality" bits that correspond to backends (Sparse, Quantized)
highestPriorityBackendTypeId(DispatchKeySet ks)894 inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
895   return (ks & backend_functionality_keys).highestPriorityTypeId();
896 }
897 
// Reports whether `k` is included in the alias key `alias`.
// This API exists because OperatorEntry.cpp has a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined), which the
// has() API disallows.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);
902 
903 // Historically, every tensor only had a single DispatchKey, and it was always
904 // something like CPU, and there wasn't any of this business where TLS
905 // could cause the DispatchKey of a tensor to change.  But we still have some
906 // legacy code that is still using DispatchKey for things like instanceof
907 // checks; if at all possible, refactor the code to stop using DispatchKey in
908 // those cases.
legacyExtractDispatchKey(DispatchKeySet s)909 inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
910   // NB: If you add any extra keys that can be stored in TensorImpl on
911   // top of existing "backend" keys like CPU/CUDA, you need to add it
912   // here.  At the moment, autograd keys and ADInplaceOrView key need this
913   // treatment;
914   return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
915           autocast_dispatch_keyset -
916           DispatchKeySet(
917               {DispatchKey::Functionalize,
918                DispatchKey::PythonTLSSnapshot,
919                DispatchKey::FuncTorchGradWrapper,
920                DispatchKey::FuncTorchVmapMode,
921                DispatchKey::FuncTorchBatched,
922                DispatchKey::Python}))
923       .highestPriorityTypeId();
924 }
925 
// Type predicate: `is_not_DispatchKeySet<T>::value` is true exactly when T is
// not the same type as DispatchKeySet.
template <class T>
using is_not_DispatchKeySet = std::negation<std::is_same<DispatchKeySet, T>>;
928 
929 // Given a function type, constructs a function_traits type that drops the first
930 // parameter type if the first parameter is of type DispatchKeySet. NB:
931 // DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
932 // pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
933 // the Dispatcher] for details). If at any point in the future we need to expose
934 // this type to JIT, revisit the usage of this type alias.
935 template <class FuncType>
936 using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
937     typename guts::infer_function_traits_t<FuncType>::return_type,
938     typename std::conditional_t<
939         std::is_same_v<
940             DispatchKeySet,
941             typename guts::typelist::head_with_default_t<
942                 void,
943                 typename guts::infer_function_traits_t<
944                     FuncType>::parameter_types>>,
945         guts::typelist::drop_if_nonempty_t<
946             typename guts::infer_function_traits_t<FuncType>::parameter_types,
947             1>,
948         typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
949 } // namespace c10
950