#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>

namespace c10 {

// backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset.
constexpr DispatchKeySet backend_dispatch_keyset =
    autogradother_backends | DispatchKeySet(DispatchKey::Dense);

// See Note [CompositeExplicitAutogradNonFunctional Key]
// We have several types of decompositions in aten, that each have their own
// alias key. You should register your decomposition to the
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) It's an out-of-place op
// (2) It decomposes into one or more mutation ops
// (3) It has a derivative formula
//     (In theory we could also have a separate key for
//     "CompositeImplicitAutogradNonFunctional", but there isn't much of a use
//     case for it currently).
// This key is important for "functional" backends like LazyTensor / XLA.
// If you're a backend that only expects to deal with "functional ops",
// then you don't want to decompose a functional op into an op that causes
// aliasing. You should just directly write a kernel for that functional op
// instead!
constexpr DispatchKeySet non_functional_backend_dispatch_keyset =
    backend_dispatch_keyset
        // XLA and LazyTensor are currently the only 2 backends in core
        // that use the functionalization pass in eager mode.
        .remove(DispatchKey::Sparse)
        .remove_backend(BackendComponent::XLABit)
        .remove_backend(BackendComponent::LazyBit);
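// Illustrative consequence (a sketch, not checked in this file): because the
// XLA and Lazy backend bits are removed here, a runtime key such as
// DispatchKey::XLA is not covered by CompositeExplicitAutogradNonFunctional,
// e.g.
//   runtimeDispatchKeySetHas(
//       DispatchKey::CompositeExplicitAutogradNonFunctional, DispatchKey::XLA)
// (defined below) is expected to return false, while the same query against
// DispatchKey::CompositeExplicitAutograd is expected to return true.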

bool isBackendDispatchKey(DispatchKey t) {
  return t != DispatchKey::Undefined
      // See Note [No Alias Keys in DispatchKeySet]
      && !isAliasDispatchKey(t)
      // Note [NestedTensor Not Included in Backend Keys]
      // NestedTensor has been explicitly removed from the "backend keyset" due
      // to incompatibility with some kernels, so we don't want it to be
      // included in CompositeExplicitAutograd kernels.
      && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
}
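// Rough usage sketch (illustrative expectations, not exercised here):
//   isBackendDispatchKey(DispatchKey::CPU);                       // true: runtime backend key
//   isBackendDispatchKey(DispatchKey::CompositeImplicitAutograd); // false: alias key
//   isBackendDispatchKey(DispatchKey::NestedTensor);              // false: excluded above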

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask].
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    autograd_dispatch_keyset |
    // See Note [NestedTensor Not Included in Backend Keys]
    // The caveat to that note is that nested_tensor is a special case
    // where we would like to support composite implicit kernels but not
    // explicit kernels, so we manually add the key to the
    // math_dispatch_keyset.
    DispatchKeySet{DispatchKey::NestedTensor} |
    // Functionalize should always re-use CompositeImplicit decomps.
    DispatchKeySet{DispatchKey::Functionalize};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      // See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
      // That's why we OR it with a mask of the backend bits here.
      // getRuntimeDispatchKeySet() expects to return a keyset of runtime
      // dispatch keys, like AutogradCPU, but that requires having backend bits.
      return autograd_dispatch_keyset |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::CompositeImplicitAutograd:
      return math_dispatch_keyset;
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return nested_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutograd:
      return backend_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      return non_functional_backend_dispatch_keyset;
    default:
      return DispatchKeySet(t);
  }
}
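// Sketch of expected behavior (illustrative only): for the Autograd alias,
//   getRuntimeDispatchKeySet(DispatchKey::Autograd).has(DispatchKey::AutogradCPU)
// should be true because the backend bits are OR'd back in above, while a
// non-alias key like DispatchKey::CPU simply falls through to the default
// case and comes back as DispatchKeySet(DispatchKey::CPU).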

bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      return autograd_dispatch_keyset.has(toFunctionalityKey(k));
    case DispatchKey::CompositeImplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      // See Note [NestedTensor Not Included in Backend Keys]
      return nested_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor &&
          non_functional_backend_dispatch_keyset.has(k);
    case DispatchKey::FuncTorchBatchedDecomposition:
      return functorch_batched_ks.has(k);
    default:
      return t == k;
  }
}
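// Illustrative expectations (a sketch, not exercised here):
//   runtimeDispatchKeySetHas(DispatchKey::CompositeImplicitAutograd,
//                            DispatchKey::NestedTensor);  // true: in math_dispatch_keyset
//   runtimeDispatchKeySetHas(DispatchKey::CompositeExplicitAutograd,
//                            DispatchKey::NestedTensor);  // false: excluded above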

// for a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  switch (t) {
    case DispatchKey::AutogradCPU:
      return DispatchKeySet(DispatchKey::CPU);
    case DispatchKey::AutogradCUDA:
      return DispatchKeySet(DispatchKey::CUDA);
    case DispatchKey::AutogradXLA:
      return DispatchKeySet(DispatchKey::XLA);
    case DispatchKey::AutogradLazy:
      return DispatchKeySet(DispatchKey::Lazy);
    case DispatchKey::AutogradMeta:
      return DispatchKeySet(DispatchKey::Meta);
    case DispatchKey::AutogradMPS:
      return DispatchKeySet(DispatchKey::MPS);
    case DispatchKey::AutogradHPU:
      return DispatchKeySet(DispatchKey::HPU);
    case DispatchKey::AutogradIPU:
      return DispatchKeySet(DispatchKey::IPU);
    case DispatchKey::AutogradXPU:
      return DispatchKeySet(DispatchKey::XPU);
    case DispatchKey::AutogradPrivateUse1:
      return DispatchKeySet(DispatchKey::PrivateUse1);
    case DispatchKey::AutogradPrivateUse2:
      return DispatchKeySet(DispatchKey::PrivateUse2);
    case DispatchKey::AutogradPrivateUse3:
      return DispatchKeySet(DispatchKey::PrivateUse3);
    case DispatchKey::AutogradNestedTensor:
      return DispatchKeySet(DispatchKey::NestedTensor) |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::AutogradOther:
      return autogradother_backends;
    default:
      return DispatchKeySet();
  }
}
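// Illustration (a sketch): getBackendKeySetFromAutograd(DispatchKey::AutogradCPU)
// yields DispatchKeySet(DispatchKey::CPU), while a non-autograd key such as
// DispatchKey::CPU falls through to the default case and yields the empty
// DispatchKeySet().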

bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}
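// Example reading (illustrative): isIncludedInAlias(DispatchKey::CPU,
// DispatchKey::CompositeExplicitAutograd) is expected to be true, and any
// query with k == DispatchKey::Undefined is false by construction.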

std::string toString(DispatchKeySet ts) {
  std::stringstream ss;
  ss << ts;
  return ss.str();
}

std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
  if (ts.empty()) {
    os << "DispatchKeySet()";
    return os;
  }
  os << "DispatchKeySet(";
  bool first = true;
  for (auto k : ts) {
    if (!first) {
      os << ", ";
    }
    os << k;
    first = false;
  }
  os << ")";
  return os;
}
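// Output sketch (illustrative; exact names come from the DispatchKey printer):
//   toString(DispatchKeySet());                  // "DispatchKeySet()"
//   toString(DispatchKeySet(DispatchKey::CPU));  // e.g. "DispatchKeySet(CPU)"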

DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
  TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
  TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_);

  // Create a masked version of the set representation to ignore previous
  // keys that we've iterated through.
  uint64_t masked_functionality_bits =
      llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
  uint64_t masked_backend_bits =
      llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
      *data_ptr_;

  uint64_t first_functionality_idx =
      llvm::findFirstSet(masked_functionality_bits);
  uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);

  // If there are no keys, set to end iterator value
  if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
      next_functionality_ == iterator::end_iter_mask_val) {
    // Set up state to be the same as end()
    next_functionality_ = iterator::end_iter_mask_val;
    current_dispatchkey_idx_ = iterator::end_iter_key_val;
    next_backend_ = 0;
    current_backendcomponent_idx_ = iterator::end_iter_key_val;
    return *this;
  }

  // The +1 is because of DispatchKey::Undefined and
  // BackendComponent::InvalidBit
  auto new_next_functionality = first_functionality_idx + 1;
  auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
  // and the -num_backends is because the first <num_backends> bits in the
  // keyset are not Dispatch Keys.
  auto next_dispatchkey_idx = new_next_functionality - num_backends;

  // If the current functionality bit is a per-backend bit, we need special
  // handling
  if (isPerBackendFunctionalityKey(
          static_cast<DispatchKey>(next_dispatchkey_idx))) {
    // case 1: if the current backend is undefined, then there is no valid
    // backend instance of this functionality key so we can skip it.
    if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // increment the functionality mask so we skip the current functionality
      // bit on the next increment.
      next_functionality_ = new_next_functionality;
      ++(*this);
      return *this;
    }

    // Otherwise, at this point we know what the current backend and
    // functionality bits are.
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    current_backendcomponent_idx_ = new_backendcomponent_idx;

    // Next, we need to set up the masks for the next increment.
    uint64_t next_backendcomponent_bits =
        llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
        full_backend_mask & *data_ptr_;
    uint64_t next_backendcomponent_idx =
        llvm::findFirstSet(next_backendcomponent_bits);
    if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // case 2: the current backend is valid, but there is not another backend
      // in the keyset. In this case, we need to bump the functionality mask and
      // reset the backend mask for the next increment
      next_functionality_ = new_next_functionality;
      next_backend_ = 0;
    } else {
      // case 3: we have another backend to iterate over. We want to iterate
      // over the same functionality bit next time, but a different backend bit.
      next_backend_ = first_backendcomponent_idx + 1;
    }
  } else {
    // Functionality bits that aren't per backend are simpler to handle. We can
    // ignore the backend bits.
    TORCH_INTERNAL_ASSERT(next_backend_ == 0);
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    next_functionality_ = new_next_functionality;
  }
  return *this;
}
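// Iteration order sketch (illustrative): keys come out in increasing
// functionality order, and a per-backend functionality bit is expanded once
// per set backend bit, lowest backend bit first. For example:
//   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::CUDA});
//   for (auto k : ks) { ... }  // expected to visit CPU, then CUDA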

std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
  std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      offsets_and_masks;
  // manually set the first entry, which corresponds to Undefined.
  offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
  // loop through every functionality key (aside from Undefined).
  for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
    // functionality_idx should be Dense -> 1, ...
    auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
    auto k = static_cast<DispatchKey>(functionality_idx);

    // If the previous functionality was not per-backend, then we can just
    // increment the previous offset. Otherwise, the next offset =
    // previous_offset + num_backends.
    auto next_offset = prev_offset_and_mask.offset +
        (prev_offset_and_mask.mask == 0 ? 1 : num_backends);
    // the mask is used in the runtime index calculation to find the offset of
    // the backend. For non-per-backend functionalities, this offset should
    // always be 0. Otherwise, we need to get the index of the backend (which we
    // can do using a backend mask).
    auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
    offsets_and_masks[functionality_idx] =
        FunctionalityOffsetAndMask(next_offset, next_mask);
  }
  // Sanity check that the computed offset index of the last functionality key
  // is correct. This assumes that the highest priority functionality key is not
  // per backend.
  TORCH_INTERNAL_ASSERT(
      offsets_and_masks[num_functionality_keys - 1].offset ==
          (num_runtime_entries - 1),
      "num_runtime_entries: ",
      num_runtime_entries,
      "last_offset: ",
      offsets_and_masks[num_functionality_keys - 1].offset);
  return offsets_and_masks;
}
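// Worked example of the offset arithmetic above (a sketch; exact keys depend
// on the DispatchKey enum layout): Undefined sits at offset 0 with mask 0.
// Dense, the first real functionality key, is per-backend, so it gets offset
// 1 and mask == full_backend_mask; the functionality key after Dense then
// starts at offset 1 + num_backends, because Dense reserves one runtime-table
// slot per backend.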

} // namespace c10