// xref: /aosp_15_r20/external/pytorch/c10/core/DispatchKey.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>

namespace c10 {

// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.

// In implementation terms, the backend component identifies a specific "bit"
// in a DispatchKeySet. The bits in the DispatchKeySet are split between the
// bottom ~15 "BackendComponent" bits, while the remaining upper bits are
// assigned to functionalities. When we encounter a functionality bit that is
// known to be customizable per-backend, then we also look at the lower
// BackendComponent bits and take the highest bit to determine which backend's
// implementation to use.
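
// An illustrative sketch of how the two halves combine (assumes the
// DispatchKeySet API declared in c10/core/DispatchKeySet.h; not code that
// lives in this file):
//
//   c10::DispatchKeySet ks(
//       {c10::DispatchKey::CPU, c10::DispatchKey::CUDA});
//   // Both keys decompose into the per-backend-customizable "Dense"
//   // functionality bit plus a backend bit (CPUBit, CUDABit).  When the
//   // Dense functionality is dispatched, the highest-priority backend bit
//   // present wins, so this set resolves to the DispatchKey::CUDA kernel.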

// WARNING!  If you add a new backend component to the end of this list,
// make sure you register it before Meta.
// Meta must be at the end so that the Meta key in TLS triggers meta kernels.
// (But you shouldn't: private use keys should have higher precedence than all
// built-in keys)

// If you add a new (non-privateuse) backend here,
// make sure to add an Autograd<Backend> fallthrough kernel
// in aten/src/ATen/core/VariableFallbackKernel.cpp

#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \
  _(CPU, extra)                                 \
  _(CUDA, extra)                                \
  _(HIP, extra)                                 \
  _(XLA, extra)                                 \
  _(MPS, extra)                                 \
  _(IPU, extra)                                 \
  _(XPU, extra)                                 \
  _(HPU, extra)                                 \
  _(VE, extra)                                  \
  _(Lazy, extra)                                \
  _(MTIA, extra)                                \
  _(PrivateUse1, extra)                         \
  _(PrivateUse2, extra)                         \
  _(PrivateUse3, extra)                         \
  _(Meta, extra)

// WARNING!  If you add a new per-backend functionality key that has higher
// priority than Autograd, make sure to update EndOfRuntimeBackendKeys.

#define C10_FORALL_FUNCTIONALITY_KEYS(_) \
  _(Dense, )                             \
  _(Quantized, Quantized)                \
  _(Sparse, Sparse)                      \
  _(SparseCsr, SparseCsr)                \
  _(NestedTensor, NestedTensor)          \
  _(AutogradFunctionality, Autograd)

enum class BackendComponent : uint8_t {

  // A "backend" is colloquially used to refer to handlers for dispatch
  // which actually implement the numerics of the operation in question.
  //
  // Due to the nature of the enum, these backends are specified in
  // an ordered way, but for most backends this order is not semantically
  // meaningful (e.g., it's valid to reorder these backends without changing
  // semantics).  The only situation when backend ordering is meaningful
  // is when the backend participates in multiple dispatch with another
  // backend; e.g., CPU and CUDA (CUDA must have higher priority).

  // These keys don't correspond to individual kernels.
  // Instead, they represent the backends that are allowed to override specific
  // pieces of functionality:
  // - dense kernels (e.g. DispatchKey::CPU)
  // - sparse kernels (e.g. DispatchKey::SparseCPU)
  // - quantized kernels (e.g. DispatchKey::QuantizedCPU)
  // - autograd kernels (e.g. DispatchKey::AutogradCPU)
  // We reserve space in the runtime operator table for the full cross product
  // of [backends in this enum] x [keys below that are explicitly marked as
  // having per-backend functionality].
  //
  // A meta tensor is a tensor without any data associated with it.  (They
  // have also colloquially been referred to as tensors on the "null" device.)
  // A meta tensor can be used to dry run operators without actually doing any
  // computation, e.g., add on two meta tensors would give you another meta
  // tensor with the output shape and dtype, but wouldn't actually add anything.
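  //
  // An illustrative meta-tensor dry run (assumes the usual ATen factory and
  // arithmetic APIs; shown here only for exposition):
  //
  //   at::Tensor a = at::empty({2, 3}, at::kMeta);
  //   at::Tensor b = at::empty({2, 3}, at::kMeta);
  //   at::Tensor c = a + b;  // no data is read or written
  //   // c is another meta tensor with sizes {2, 3} and the propagated dtype,
  //   // but no addition was actually performed.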

  InvalidBit = 0,
#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit,
  C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused)
#undef DEFINE_BACKEND_COMPONENT

  // Define an alias to represent the end of backend dispatch keys.
  // If you add new backend keys after PrivateUse3, please also update it here.
  EndOfBackendKeys = MetaBit,
};

// Semantically, a dispatch key identifies a possible "level" in our
// dispatch, for which a handler may be registered. Each handler corresponds
// to a type of functionality.
//
// In implementation terms, the dispatch key identifies a specific "bit" in a
// DispatchKeySet.  Higher bit indexes get handled by dispatching first (because
// we "count leading zeros" when we extract the highest priority dispatch
// key.)
//
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalities that are customizable per backend
//     (e.g. Dense, Sparse, AutogradFunctionality)
// (4) per-backend instances of customizable functionalities
//     (e.g. CPU, SparseCPU, AutogradCPU)
// (5) alias keys (e.g. CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
//     ("runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.
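//
// A worked reading of that split, using the examples the note itself cites
// (illustrative only; the enum below is authoritative):
//
//   category                        | DispatchKeySet bit | runtime table slot
//   --------------------------------+--------------------+-------------------
//   (1) FPGA                        | yes                | yes
//   (2) Functionalize               | yes                | yes
//   (3) Dense, Sparse, ...          | yes                | no
//   (4) CPU, SparseCPU, ...         | no                 | yes
//   (5) CompositeImplicitAutograd   | no                 | no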

// See Note [DispatchKeySet Internal Representation] for more details.
//
// NOTE: Keep the list in sync with `DispatchKey` in torchgen/model.py
enum class DispatchKey : uint16_t {

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // This is not a "real" functionality, but it exists to give us a "nullopt"
  // element we can return for cases when a DispatchKeySet contains no elements.
  // You can think of a more semantically accurate definition of DispatchKey as:
  //
  //    using DispatchKey = std::optional<RealDispatchKey>
  //
  // and Undefined == nullopt.  We didn't actually represent
  // it this way because std::optional<RealDispatchKey> would take two
  // words, when DispatchKey fits in sixteen bits.

  Undefined = 0,

  // Define an alias for Undefined to represent CatchAll (long term
  // this will get eliminated, but for now it's convenient)
  CatchAll = Undefined,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
  // Every value in the enum (up to EndOfFunctionalityKeys)
  // corresponds to an individual "functionality" that can be dispatched to.
  // This is represented in the DispatchKeySet by assigning each of these enum
  // values to one of the remaining (64 - len(BackendComponent)) bits.
  //
  // Most of these functionalities have a single handler assigned to them,
  // making them "runtime keys" that map to a single slot in the runtime
  // operator table.
  //
  // A few functionalities are allowed to be customizable per backend.
  // See [Note: Per-Backend Functionality Dispatch Keys] for details.

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Dense,

  // Below are non-extensible backends.
  // These are backends that currently don't have their own overrides for
  // Autograd/Sparse/Quantized kernels, and we therefore don't waste space in
  // the runtime operator table allocating space for them.
  // If any of these backends ever need to customize, e.g., Autograd, then we'll
  // need to add a DispatchKey::*Bit for them.

  // TODO: put this in BackendComponents
  FPGA, // Xilinx support lives out of tree at
  // https://gitlab.com/pytorch-complex/vitis_kernels

  // TODO: put this in BackendComponents
  // MAIA backend lives out of tree
  // - test/cpp_extensions/maia_extension.cpp
  // - test/test_torch.py
  // - aten/src/ATen/test/extension_backend_test.cpp
  MAIA,

  Vulkan, // TODO: put this in BackendComponents
  Metal, // TODO: put this in BackendComponents

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Quantized,

  // This backend is to support custom RNGs; it lets you go
  // to a different kernel if you pass in a generator that is not a
  // traditional CPUGeneratorImpl/CUDAGeneratorImpl.  To make use of this
  // key:
  //  1) set it as a second parameter of at::Generator constructor call in
  //     the user-defined PRNG class.
  //  2) use it as a dispatch key while registering custom kernels
  //     (templatized kernels specialized for user-defined PRNG class)
  // Intended for out-of-tree use; tested by aten/src/ATen/test/rng_test.cpp
  CustomRNGKeyId,

  // TODO: Make Mkldnn a functionality key, so we can give it Meta
  // support
  // Here are backends which specify more specialized operators
  // based on the layout of the tensor.  Note that the sparse backends
  // are one case where ordering matters: sparse multi-dispatches with
  // the corresponding dense tensors, and must be handled before them.
  MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
  // NB: not to be confused with MKLDNN, which is Caffe2 only

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Sparse,

  SparseCsr,

  NestedTensor,

  // In some situations, it is not immediately obvious what the correct
  // backend for a function is, because the function in question doesn't
  // have any "tensor" arguments.  In this case, a BackendSelect function
  // can be registered to implement the custom determination of the
  // correct backend.
  BackendSelect,

  Python,

  // Out-of-core key for Fake Tensor in torchdistx.
  // See https://pytorch.org/torchdistx/latest/fake_tensor.html
  // TODO: delete this in favor of Python-implemented fake tensor
  Fake,
  // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
  // is to insert code after the "autograd subsystem" runs, so this key should
  // be directly after ADInplaceOrView and all of the autograd keys.
  FuncTorchDynamicLayerBackMode,

  // Alias and mutation removal.
  // If some backends want to opt into only alias removal or only mutation
  // removal, we can consider adding separate keys dedicated to those
  // individual passes.
  // See Note [Functionalization Pass In Core] for details.
  Functionalize,

  // The named dispatch key is set for any tensors with named dimensions.
  // Although we have a dispatch key for named tensors, for historical reasons,
  // this dispatch key doesn't do any of the substantive functionality for named
  // tensors (though, hypothetically, it could!)  At the moment, it's just
  // responsible for letting us give good error messages when operations
  // don't support named tensors.
  //
  // NB: If you ever consider moving named tensor functionality into
  // this dispatch key, note that it might be necessary to add another dispatch
  // key that triggers before composite operators, in case a composite operator
  // has named dimension propagation that doesn't match that of its
  // constituent parts.
  // TODO: delete this once torchdim lands in functorch
  Named,

  // The Conjugate dispatch key is set for any tensors that need to perform
  // conjugation.
  // This is implemented at a dispatch level right before any backends run.
  Conjugate,

  // The Negative dispatch key is set for any tensors that need to perform
  // negation.
  // This is implemented at a dispatch level right before any backends run.
  Negative,

  ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp

  // Note [ADInplaceOrView key]
  // ADInplaceOrView key is used by inplace or view ops to register a kernel
  // that does additional setup for future autograd computation.
  //
  // 1. For inplace ops this kernel does version bump
  // 2. For view ops this kernel does `as_view` setup where we properly setup
  //    DifferentiableViewMeta on the view tensors.
  //
  // For other ops it's a fallthrough kernel, since there's no extra
  // work to do.
  //
  // Note [Dream: skip VariableType kernel when requires_grad=false]
  //
  // In an ideal world where we can skip the VariableType kernel for inputs
  // with requires_grad=false, instead of a fallthrough kernel, we'll
  // register a kernel like the one shown below to all functional ops as well:
  // torch::Tensor my_functional_op(...) {
  //   {
  //     // Note: for every op in VariableType, you need to go through
  //     // the `AutoDispatchBelowADInplaceOrView` guard exactly once to add the
  //     // key to the TLS excluded set. If you don't go through it at all,
  //     // inplace/view ops called through `at::` inside your backend
  //     // kernel will dispatch to ADInplaceOrView kernels and do a lot
  //     // of extra work.
  //     at::AutoDispatchBelowADInplaceOrView guard;
  //     at::redispatch::my_functional_op(...);
  //   }
  // }
  // But this work is currently blocked since it adds an extra dispatch
  // for all ops, and that's non-trivial overhead at the model level (a few
  // percent).  Thus our current approach takes advantage of the fact that
  // every kernel goes through the VariableType kernel first, and pulls the
  // `at::AutoDispatchBelowADInplaceOrView` guard of functional ops
  // up to the `VariableType` kernel. Thus we only add the extra dispatch
  // to view/inplace ops to minimize its perf impact on real models.
  ADInplaceOrView,
  // Note [Alias Dispatch Key : Autograd]
  // All backends are oblivious to autograd; autograd is handled as a
  // layer which happens on top of all backends. It inspects the autograd
  // metadata of all inputs, determines what autograd metadata should be
  // constructed by the output, and otherwise defers to the backend to
  // actually do the numeric computation.  Autograd contains
  // the bulk of this logic.

  // Autograd is now an alias dispatch key which by default maps to all
  // backend-specific autograd keys.
  // Backend-specific keys allow backends to override the default kernel
  // registered to the Autograd key as needed.
  // For example, XLA wants to define autograd for einsum directly.
  // Registering a custom autograd implementation at the XLA key won't work
  // because we process Autograd before XLA.  This key has higher priority and
  // gets processed first.  You generally should NOT redispatch after handling
  // autograd here (since that would result in execution of the Autograd
  // operator, which you're trying to skip).  In AutogradXLA implementations,
  // you are responsible for handling autograd yourself, or deferring to other
  // operators which support autograd.
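  //
  // An illustrative registration sketch for the AutogradXLA case above
  // (assumes the TORCH_LIBRARY_IMPL macro from torch/library.h and a
  // hypothetical xla_einsum_with_autograd kernel; not code in this file):
  //
  //   TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
  //     m.impl("einsum", xla_einsum_with_autograd);
  //   }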

  // Currently we only have backend-specific autograd keys for CPU/CUDA/XLA and
  // reserved user-defined backends. All other in-tree backends share the
  // AutogradOther key. We can add a specific autograd key for those backends
  // upon request.
  AutogradOther,

  // See [Note: Per-Backend Functionality Dispatch Keys]
  AutogradFunctionality,

  // NestedTensor is an example of something that isn't a "real backend"
  // (because it mostly consists of redispatching kernels)
  // but it would like to override autograd functionality in C++.
  // We can handle cases like this by adding an extra functionality key
  // exclusively for handling autograd for NestedTensor.
  // Lives out of tree at
  // https://github.com/pytorch/nestedtensor
  AutogradNestedTensor,

  Tracer,

  // TODO: make Autocast a functionality key
  // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
  // and inputs are saved for backward in the post-autocast type.
  AutocastCPU,
  AutocastXPU,
  AutocastIPU,
  AutocastHPU,
  AutocastXLA,
  // AutocastXLA is only being used for TPUs. XLA GPUs continue to use
  // AutocastCUDA.
  AutocastMPS,
  AutocastCUDA,
  AutocastPrivateUse1,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // There are a number of alternative modes which may want to run before
  // autograd; for example, error checking, tracing, profiling or vmap.  They
  // go here.

  FuncTorchBatched, // See Note [Out-of-tree vmap+grad prototype]

  // Dispatch key for BatchedTensorImpl wrapping a nested tensor.
  BatchedNestedTensor,

  FuncTorchVmapMode, // See Note [Out-of-tree vmap+grad prototype]

  // This is the dispatch key for BatchedTensorImpl, which is used to implement
  // batching rules for vmap.
  Batched,

  // When we are inside a vmap, all tensors dispatch on this key.
  // See Note: [DispatchKey::VmapMode usage] for more details.
  VmapMode,

  FuncTorchGradWrapper, // See Note [Out-of-tree vmap+grad prototype]

  // Out-of-core key for Deferred Module Initialization in torchdistx.
  // See https://pytorch.org/torchdistx/latest/deferred_init.html
  DeferredInit,

  // Used by Python key logic to know the set of TLS on entry to the dispatcher.
  // This kernel assumes it is the top-most non-functorch-related DispatchKey.
  // If you add a key above, make sure to update the fallback implementation for
  // this.
  PythonTLSSnapshot,

  // This key should be at the very top of the dispatcher
  FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype]

  // TESTING: This is intended to be a generic testing tensor type id.
  // Don't use it for anything real; its only acceptable use is within a single
  // process test.  Use it by creating a TensorImpl with this DispatchKey, and
  // then registering operators to operate on this type id.  See
  // aten/src/ATen/core/dispatch/backend_fallback_test.cpp for a usage example.
  TESTING_ONLY_GenericWrapper,

  // TESTING: This is intended to be a generic testing tensor type id.
  // Don't use it for anything real; its only acceptable use is within a single
  // process test.  Use it by toggling the mode on and off via
  // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators
  // to operate on this type id.  See
  // aten/src/ATen/core/dispatch/backend_fallback_test.cpp
  // for a usage example.
  TESTING_ONLY_GenericMode,

  // This key is used for pre-dispatch tracing in make_fx.
  // It has lower priority than the PythonDispatcher key
  // because we use the PythonDispatcher to intercept the key from python,
  // and avoid having to implement it in C++.
  PreDispatch,

  // This is a bypass that allows you to skip running the C++ dispatcher
  // entirely.
  PythonDispatcher,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  EndOfFunctionalityKeys, // End of functionality keys.

// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.

#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n,

#define DEFINE_PER_BACKEND_KEYS(fullname, prefix)      \
  StartOf##fullname##Backends,                         \
      C10_FORALL_BACKEND_COMPONENTS(                   \
          DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \
          EndOf##fullname##Backends = prefix##Meta,

  C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS)

#undef DEFINE_PER_BACKEND_KEYS
#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND
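
// For reference, an illustrative (abbreviated) expansion of
// DEFINE_PER_BACKEND_KEYS(Sparse, Sparse) from the block above:
//
//   StartOfSparseBackends,
//   SparseCPU,
//   SparseCUDA,
//   ... one entry per BackendComponent ...
//   SparseMeta,
//   EndOfSparseBackends = SparseMeta,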

      EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends,

  // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // Note [Alias Dispatch Keys]
  // Alias dispatch keys are synthetic dispatch keys which map to multiple
  // runtime dispatch keys. Alias keys have precedence, but they are always
  // lower precedence than runtime keys. You can register a kernel to an
  // alias key; the kernel may then be populated into the mapped runtime keys
  // during dispatch table computation.
  // If a runtime dispatch key has multiple kernels from alias keys, which
  // kernel wins is decided based on the precedence of the alias keys (but
  // runtime keys always have precedence over alias keys).
  // Alias keys won't be directly called during runtime.

  // See Note [Alias Dispatch Key : Autograd]
  Autograd,
  CompositeImplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp

  // Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from
  // all other alias keysets, and so precedence order doesn't matter.
  FuncTorchBatchedDecomposition, // registered at
  // build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp
  // Note: The alias keyset for CompositeImplicitAutogradNestedTensor is
  // disjoint from all other alias keysets
  CompositeImplicitAutogradNestedTensor, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
  CompositeExplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
  // See Note [CompositeExplicitAutogradNonFunctional Key]
  CompositeExplicitAutogradNonFunctional, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp

  // Define an alias key to represent the end of alias dispatch keys.
  // If you add new alias keys after Autograd, please also update it here.
  StartOfAliasKeys = Autograd,
  EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, //

  // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // These aliases exist for backwards-compatibility reasons; they shouldn't
  // be used.
  CPUTensorId = CPU,
  CUDATensorId = CUDA,
  DefaultBackend = CompositeExplicitAutograd,
  PrivateUse1_PreAutograd = AutogradPrivateUse1,
  PrivateUse2_PreAutograd = AutogradPrivateUse2,
  PrivateUse3_PreAutograd = AutogradPrivateUse3,
  Autocast = AutocastCUDA,
};

// Note [Private use DispatchKey]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Private use tensor IDs are preallocated tensor type IDs for use in user
// applications.  Similar to private use fields in HTTP, they can be used
// by end users for experimental or private applications, without needing
// to "standardize" the tensor ID (which would be done by submitting a PR
// to PyTorch to add your type ID).
//
// Private use tensor IDs are appropriate to use if you want to experiment
// with adding a new tensor type (without having to patch PyTorch first) or
// have a private, non-distributed application that needs to make use of a
// new tensor type.  Private use tensor IDs are NOT appropriate to use for
// libraries intended to be distributed to further users: please contact
// the PyTorch developers to get a type ID registered in this case.
//
// We provide two classes of private use tensor id: regular DispatchKeys
// and Autograd DispatchKeys.  DispatchKeys serve the role of ordinary "backend"
// DispatchKeys; if you were adding support for a new type of accelerator, you
// would use a backend DispatchKey, and ideally automatically reuse
// AutogradOther definitions already defined in PyTorch.  AutogradPrivateUse
// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for
// tensors that compose multiple internal tensors, and for cases when the
// built-in autograd formulas for operators are not appropriate.
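//
// An illustrative out-of-tree registration sketch (assumes the
// TORCH_LIBRARY_IMPL macro from torch/library.h and a hypothetical
// my_device_add kernel; shown only to indicate how these keys are meant to
// be used):
//
//   TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
//     m.impl("add.Tensor", my_device_add);
//   }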

static_assert(
    (static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) +
     static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64,
    "The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)"
    " both map to backend and functionality bits"
    " into a 64-bit bitmask; you must have less than 64 total entries between them");

// Check if a DispatchKey is an alias mapping to other runtime keys.
constexpr bool isAliasDispatchKey(DispatchKey k) {
  return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys;
}
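
// For example (these values follow directly from the enum above):
//   isAliasDispatchKey(DispatchKey::Autograd) == true
//   isAliasDispatchKey(DispatchKey::CompositeImplicitAutograd) == true
//   isAliasDispatchKey(DispatchKey::CPU) == false  // runtime key, not an alias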

// [Note: Per-Backend Functionality Dispatch Keys]
// Check if a DispatchKey is a per-backend functionality key
// Any functionalities that can be customized per-backend should be added here.
// These keys correspond to functionalities that can be customized individually
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
// they map to (# backends) slots in the operator table.
// Each of these keys also has a separate set of "runtime keys" in the dispatch
// key enum, per backend, which *do* map to the individual operator table slots.
// For example, the "Sparse" key maps to an individual bit in the
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
// slots in the runtime operator table.

constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
  if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
      k == DispatchKey::Sparse || k == DispatchKey::SparseCsr ||
      k == DispatchKey::AutogradFunctionality ||
      k == DispatchKey::NestedTensor) {
    return true;
  } else {
    return false;
  }
}
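
// For example (these values follow from the definition above):
//   isPerBackendFunctionalityKey(DispatchKey::Sparse) == true
//   isPerBackendFunctionalityKey(DispatchKey::SparseCPU) == false  // runtime key
//   isPerBackendFunctionalityKey(DispatchKey::FPGA) == false  // non-customizable backend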

// Note that this includes Undefined in the total count.
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
// In the above example, there are 3 total functionality keys.
constexpr uint8_t num_functionality_keys =
    static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);

constexpr uint8_t num_backends =
    static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);

// Note [No More Than 16 Backends]
// Search for this note to find places in the code where the "no more than 16
// backends" invariant is baked in.
static_assert(
    static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
    "BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
there are a few places where this invariant is baked in");

constexpr uint8_t numPerBackendFunctionalityKeys() {
  uint8_t count = 0;
  for (uint8_t k = 0; k <= num_functionality_keys; ++k) {
    if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k)))
      ++count;
  }
  return count;
}

#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// See [Note: Trimmed Mobile Dispatch Keys]
constexpr uint16_t num_runtime_entries = 8;
#else
constexpr uint16_t num_runtime_entries = num_functionality_keys +
    (numPerBackendFunctionalityKeys() * (num_backends - 1));
#endif
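
// An illustrative reading of the formula above: with F = num_functionality_keys,
// P = numPerBackendFunctionalityKeys() and B = num_backends, the runtime
// operator table needs F + P * (B - 1) entries, since each per-backend
// functionality is already counted once in F and then needs one extra slot
// for each additional backend.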

// See Note [No More Than 16 Backends]
constexpr uint16_t full_backend_mask =
    (static_cast<uint16_t>(1) << num_backends) - 1;

C10_API const char* toString(DispatchKey);
C10_API const char* toString(BackendComponent);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);

C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);

// Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception.
C10_API c10::DispatchKey parseDispatchKey(const std::string& k);

// These are some convenience identifiers for dispatch keys which are
// shorter to type than their long counterparts.  Note that some of these
// dispatch keys directly correspond to DeviceType; and most APIs that
// accept DispatchKey also accept DeviceType; e.g.,
// torch::dispatch(torch::kCPU, ...) is also valid.
constexpr DispatchKey kAutograd = DispatchKey::Autograd;

// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr BackendComponent toBackendComponent(DispatchKey k) {
  if (k >= DispatchKey::StartOfDenseBackends &&
      k <= DispatchKey::EndOfDenseBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
  } else if (
      k >= DispatchKey::StartOfQuantizedBackends &&
      k <= DispatchKey::EndOfQuantizedBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
  } else if (
      k >= DispatchKey::StartOfSparseBackends &&
      k <= DispatchKey::EndOfSparseBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
  } else if (
      k >= DispatchKey::StartOfSparseCsrBackends &&
      k <= DispatchKey::EndOfSparseCsrBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends));
  } else if (
      k >= DispatchKey::StartOfNestedTensorBackends &&
      k <= DispatchKey::EndOfNestedTensorBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends));
  } else if (
      k >= DispatchKey::StartOfAutogradFunctionalityBackends &&
      k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(
            DispatchKey::StartOfAutogradFunctionalityBackends));
  } else {
    return BackendComponent::InvalidBit;
  }
}
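
// For example (these equalities follow from the per-backend key layout above):
//   toBackendComponent(DispatchKey::CUDA) == BackendComponent::CUDABit
//   toBackendComponent(DispatchKey::SparseCPU) == BackendComponent::CPUBit
//   toBackendComponent(DispatchKey::Functionalize) == BackendComponent::InvalidBit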

constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
  if (k <= DispatchKey::EndOfFunctionalityKeys) {
    return k;
  } else if (k <= DispatchKey::EndOfDenseBackends) {
    return DispatchKey::Dense;
  } else if (k <= DispatchKey::EndOfQuantizedBackends) {
    return DispatchKey::Quantized;
  } else if (k <= DispatchKey::EndOfSparseBackends) {
    return DispatchKey::Sparse;
  } else if (k <= DispatchKey::EndOfSparseCsrBackends) {
    return DispatchKey::SparseCsr;
  } else if (k <= DispatchKey::EndOfNestedTensorBackends) {
    return DispatchKey::NestedTensor;
  } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
    return DispatchKey::AutogradFunctionality;
  } else {
    return DispatchKey::Undefined;
  }
}
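
// For example:
//   toFunctionalityKey(DispatchKey::CPU) == DispatchKey::Dense
//   toFunctionalityKey(DispatchKey::SparseCUDA) == DispatchKey::Sparse
//   toFunctionalityKey(DispatchKey::AutogradCPU) == DispatchKey::AutogradFunctionality
//   toFunctionalityKey(DispatchKey::Functionalize) == DispatchKey::Functionalize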

BackendComponent toBackendComponent(DeviceType device_type);

// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns
// DispatchKey::CUDA.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
    DispatchKey functionality_k,
    BackendComponent backend_k) {
  if (functionality_k == DispatchKey::Dense) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::Sparse) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::SparseCsr) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::Quantized) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::NestedTensor) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::AutogradFunctionality) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(
            DispatchKey::StartOfAutogradFunctionalityBackends) +
        static_cast<uint8_t>(backend_k));
  }
  return DispatchKey::Undefined;
}
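
// A consequence of the definitions above (illustrative; holds for any runtime
// per-backend key k such as DispatchKey::SparseCUDA):
//   toRuntimePerBackendFunctionalityKey(
//       toFunctionalityKey(k), toBackendComponent(k)) == k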

} // namespace c10

namespace torch {
// Expose the constant, but not the TYPE (DispatchKey is an implementation
// detail!)
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::kAutograd;
} // namespace torch

// NB: You really shouldn't use this instance; this enum is guaranteed
// to be pretty small so a regular array should be acceptable.
namespace std {
template <>
struct hash<c10::DispatchKey> {
  typedef size_t result_type;
  typedef c10::DispatchKey argument_type;

  size_t operator()(c10::DispatchKey x) const {
    return static_cast<size_t>(x);
  }
};
} // namespace std