#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>

namespace c10 {

// backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset
constexpr DispatchKeySet backend_dispatch_keyset =
    autogradother_backends | DispatchKeySet(DispatchKey::Dense);
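// As a rough illustration of how this alias key is used (a sketch, not the
// definitive registration API; "my_op" / "my_op_kernel" are hypothetical):
//   TORCH_LIBRARY_IMPL(aten, CompositeExplicitAutograd, m) {
//     m.impl("my_op", my_op_kernel);
//   }
// registers a single kernel that the dispatcher fans out to every runtime
// key in backend_dispatch_keyset (CPU, CUDA, and the other dense backends).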

// See Note [CompositeExplicitAutogradNonFunctional Key]
// We have several types of decompositions in aten, each with their own
// alias key. You should register your decomposition to the
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) It's an out-of-place op
// (2) It decomposes into one or more mutation ops
// (3) It has a derivative formula
// (In theory we could also have a separate key for
// "CompositeImplicitAutogradNonFunctional", but there isn't much of a use
// case for it currently).
// This key is important for "functional" backends like LazyTensor / XLA.
// If you're a backend that only expects to deal with "functional ops",
// then you don't want to decompose a functional op into an op that causes
// aliasing. You should just directly write a kernel for that functional op
// instead!
constexpr DispatchKeySet non_functional_backend_dispatch_keyset =
    backend_dispatch_keyset
        // XLA and LazyTensor are currently the only 2 backends in core
        // that use the functionalization pass in eager mode.
        .remove(DispatchKey::Sparse)
        .remove_backend(BackendComponent::XLABit)
        .remove_backend(BackendComponent::LazyBit);
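// Illustrative sketch (hypothetical op, not a real aten entry): a
// decomposition like
//   Tensor my_functional_op(const Tensor& x) {
//     auto out = at::empty_like(x);
//     out.copy_(x); // decomposes into a mutation op
//     return out;
//   }
// is out-of-place but mutates an intermediate, so it belongs under
// CompositeExplicitAutogradNonFunctional rather than
// CompositeExplicitAutograd; because the XLA and Lazy backend bits are
// removed above, those functionalized backends never see it and can supply
// their own functional kernels instead.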

bool isBackendDispatchKey(DispatchKey t) {
  return t != DispatchKey::Undefined
      // See Note [No Alias Keys in DispatchKeySet]
      && !isAliasDispatchKey(t)
      // Note [NestedTensor Not Included in Backend Keys]
      // NestedTensor has been explicitly removed from the "backend keyset" due
      // to incompatibility with some kernels, so we don't want it to be
      // included in CompositeExplicitAutograd kernels.
      && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
}

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask]
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    autograd_dispatch_keyset |
    // See Note [NestedTensor Not Included in Backend Keys]
    // The caveat to that note is that nested_tensor is a special case
    // where we would like to support composite implicit kernels but not
    // explicit kernels, therefore we manually add the key to the
    // math_dispatch_keyset
    DispatchKeySet{DispatchKey::NestedTensor} |
    // Functionalize should always re-use CompositeImplicit decomps.
    DispatchKeySet{DispatchKey::Functionalize};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      // See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
      // That's why we OR it with a mask of the backend bits here.
      // getRuntimeDispatchKeySet() expects to return a keyset of runtime
      // dispatch keys, like AutogradCPU, but that requires having backend
      // bits.
      return autograd_dispatch_keyset |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::CompositeImplicitAutograd:
      return math_dispatch_keyset;
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return nested_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutograd:
      return backend_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      return non_functional_backend_dispatch_keyset;
    default:
      return DispatchKeySet(t);
  }
}
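// Illustrative expansion (a sketch of the intended behavior, based on the
// comments above rather than an exhaustive list):
//   getRuntimeDispatchKeySet(DispatchKey::Autograd)
//       -> contains the per-backend autograd keys such as AutogradCPU,
//          AutogradCUDA, AutogradXLA, ...
//   getRuntimeDispatchKeySet(DispatchKey::CompositeImplicitAutograd)
//       -> math_dispatch_keyset (backend keys + autograd keys + NestedTensor
//          + Functionalize)
//   getRuntimeDispatchKeySet(DispatchKey::CPU)
//       -> just DispatchKeySet(DispatchKey::CPU), via the default case.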

bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      return autograd_dispatch_keyset.has(toFunctionalityKey(k));
    case DispatchKey::CompositeImplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      // See Note [NestedTensor Not Included in Backend Keys]
      return nested_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor &&
          non_functional_backend_dispatch_keyset.has(k);
    case DispatchKey::FuncTorchBatchedDecomposition:
      return functorch_batched_ks.has(k);
    default:
      return t == k;
  }
}

// for a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  switch (t) {
    case DispatchKey::AutogradCPU:
      return DispatchKeySet(DispatchKey::CPU);
    case DispatchKey::AutogradCUDA:
      return DispatchKeySet(DispatchKey::CUDA);
    case DispatchKey::AutogradXLA:
      return DispatchKeySet(DispatchKey::XLA);
    case DispatchKey::AutogradLazy:
      return DispatchKeySet(DispatchKey::Lazy);
    case DispatchKey::AutogradMeta:
      return DispatchKeySet(DispatchKey::Meta);
    case DispatchKey::AutogradMPS:
      return DispatchKeySet(DispatchKey::MPS);
    case DispatchKey::AutogradHPU:
      return DispatchKeySet(DispatchKey::HPU);
    case DispatchKey::AutogradIPU:
      return DispatchKeySet(DispatchKey::IPU);
    case DispatchKey::AutogradXPU:
      return DispatchKeySet(DispatchKey::XPU);
    case DispatchKey::AutogradPrivateUse1:
      return DispatchKeySet(DispatchKey::PrivateUse1);
    case DispatchKey::AutogradPrivateUse2:
      return DispatchKeySet(DispatchKey::PrivateUse2);
    case DispatchKey::AutogradPrivateUse3:
      return DispatchKeySet(DispatchKey::PrivateUse3);
    case DispatchKey::AutogradNestedTensor:
      return DispatchKeySet(DispatchKey::NestedTensor) |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::AutogradOther:
      return autogradother_backends;
    default:
      return DispatchKeySet();
  }
}

bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}
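// For example (illustrative): isIncludedInAlias(DispatchKey::AutogradCPU,
// DispatchKey::Autograd) checks whether the AutogradCPU runtime key is
// covered by the Autograd alias key (expected to be true), and any query with
// k == DispatchKey::Undefined is false.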

std::string toString(DispatchKeySet ts) {
  std::stringstream ss;
  ss << ts;
  return ss.str();
}

std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
  if (ts.empty()) {
    os << "DispatchKeySet()";
    return os;
  }
  os << "DispatchKeySet(";
  bool first = true;
  for (auto k : ts) {
    if (!first) {
      os << ", ";
    }
    os << k;
    first = false;
  }
  os << ")";
  return os;
}
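// Example output (illustrative; the exact names come from the DispatchKey
// printer): an empty set streams as "DispatchKeySet()", while a set holding
// the CPU and CUDA runtime keys streams as something like
// "DispatchKeySet(CPU, CUDA)".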

DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
  TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
  TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_);

  // Create a masked version of the set representation to ignore previous
  // keys that we've iterated through.
  uint64_t masked_functionality_bits =
      llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
  uint64_t masked_backend_bits =
      llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
      *data_ptr_;

  uint64_t first_functionality_idx =
      llvm::findFirstSet(masked_functionality_bits);
  uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);
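  // (For reference: llvm::maskTrailingZeros<uint64_t>(N) is a mask whose N
  // lowest bits are cleared and all higher bits are set, and
  // llvm::findFirstSet returns the index of the lowest set bit, or
  // std::numeric_limits<uint64_t>::max() when no bit is set; that sentinel is
  // how "no more keys" is detected below.)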

  // If there are no keys, set to end iterator value
  if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
      next_functionality_ == iterator::end_iter_mask_val) {
    // Set up state to be the same as end()
    next_functionality_ = iterator::end_iter_mask_val;
    current_dispatchkey_idx_ = iterator::end_iter_key_val;
    next_backend_ = 0;
    current_backendcomponent_idx_ = iterator::end_iter_key_val;
    return *this;
  }

  // The +1 is because of DispatchKey::Undefined and
  // BackendComponent::InvalidBit
  auto new_next_functionality = first_functionality_idx + 1;
  auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
  // and the -num_backends is because the first <num_backends> bits in the
  // keyset are not Dispatch Keys.
  auto next_dispatchkey_idx = new_next_functionality - num_backends;

  // If the current functionality bit is a per-backend bit, we need special
  // handling
  if (isPerBackendFunctionalityKey(
          static_cast<DispatchKey>(next_dispatchkey_idx))) {
    // case 1: if the current backend is undefined, then there is no valid
    // backend instance of this functionality key so we can skip it.
    if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // increment the functionality mask so we skip the current functionality
      // bit on the next increment.
      next_functionality_ = new_next_functionality;
      ++(*this);
      return *this;
    }

    // Otherwise, at this point we know what the current backend and
    // functionality bits are.
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    current_backendcomponent_idx_ = new_backendcomponent_idx;

    // Next, we need to set up the masks for the next increment.
    uint64_t next_backendcomponent_bits =
        llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
        full_backend_mask & *data_ptr_;
    uint64_t next_backendcomponent_idx =
        llvm::findFirstSet(next_backendcomponent_bits);
    if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // case 2: the current backend is valid, but there is not another backend
      // in the keyset. In this case, we need to bump the functionality mask
      // and reset the backend mask for the next increment
      next_functionality_ = new_next_functionality;
      next_backend_ = 0;
    } else {
      // case 3: we have another backend to iterate over. We want to iterate
      // over the same functionality bit next time, but a different backend
      // bit.
      next_backend_ = first_backendcomponent_idx + 1;
    }
  } else {
    // Functionality bits that aren't per backend are simpler to handle. We can
    // ignore the backend bits.
    TORCH_INTERNAL_ASSERT(next_backend_ == 0);
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    next_functionality_ = new_next_functionality;
  }
  return *this;
}
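// Rough sketch of the iteration this operator produces (illustrative,
// assuming CPU and CUDA occupy the two lowest backend bits): for a keyset
// with the Dense functionality bit plus the CPU and CUDA backend bits set,
// successive increments yield the runtime keys CPU and then CUDA (case 3
// followed by case 2 above), after which any non-per-backend functionality
// bit present in the set is yielded on its own, ignoring the backend bits.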

std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
  std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      offsets_and_masks;
  // manually set the first entry, which corresponds to Undefined.
  offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
  // loop through every functionality key (aside from Undefined).
  for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
    // functionality_idx should be Dense -> 1, ...
    auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
    auto k = static_cast<DispatchKey>(functionality_idx);

    // If the previous functionality was not per-backend, then we can just
    // increment the previous offset. Otherwise, the next offset =
    // previous_offset + num_backends.
    auto next_offset = prev_offset_and_mask.offset +
        (prev_offset_and_mask.mask == 0 ? 1 : num_backends);
    // the mask is used in the runtime index calculation to find the offset of
    // the backend. For non-per-backend functionalities, this offset should
    // always be 0. Otherwise, we need to get the index of the backend (which
    // we can do using a backend mask).
    auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
    offsets_and_masks[functionality_idx] =
        FunctionalityOffsetAndMask(next_offset, next_mask);
  }
  // Sanity check that the computed offset index of the last functionality key
  // is correct. This assumes that the highest priority functionality key is
  // not per backend.
  TORCH_INTERNAL_ASSERT(
      offsets_and_masks[num_functionality_keys - 1].offset ==
          (num_runtime_entries - 1),
      "num_runtime_entries: ",
      num_runtime_entries,
      "last_offset: ",
      offsets_and_masks[num_functionality_keys - 1].offset);
  return offsets_and_masks;
}
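// Worked example of the loop above (using only what this function computes):
// entry 0 (Undefined) is (offset 0, mask 0). Dense, at functionality index 1,
// follows a non-per-backend entry, so its offset is 0 + 1 = 1, and because
// Dense is per-backend its mask is full_backend_mask. Whatever functionality
// follows Dense then starts at offset 1 + num_backends, since Dense reserves
// one runtime slot per backend.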

} // namespace c10