#include <ATen/core/dispatch/OperatorEntry.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/ObservedOperators.h>

namespace c10 {
namespace impl {

namespace {
#ifndef STRIP_ERROR_MESSAGES
  std::string toString(std::optional<DispatchKey> k) {
    if (k.has_value()) {
      return toString(*k);
    } else {
      return "(catch all)";
    }
  }
#endif
}

OperatorEntry::OperatorEntry(OperatorName&& operator_name)
: name_(std::move(operator_name))
, schema_()
#ifndef C10_MOBILE
, tags_()
#endif
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, kernels_()
, cpp_signature_()
, sym_cpp_signature_()
, is_observed_(ObservedOperators::isObserved(name_))
{
  // Pick up any backend fallbacks that were registered prior to this
  // OperatorEntry being created
  updateDispatchTableFull_(c10::Dispatcher::singleton());
}

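// A sketch of the fallback pickup above (illustrative; the fallback below is
// hypothetical and not registered in this file): a boxed fallback installed
// before an operator is defined, e.g.
//
//   TORCH_LIBRARY_IMPL(_, AutogradCPU, m) {
//     m.fallback(torch::CppFunction::makeFallthrough());
//   }
//
// is already reflected in dispatchTable_ by the time this constructor returns.
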
namespace {
  void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) {
    // TODO: figure out if we can just directly save real schema at def time
    FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed());
    FunctionSchema inferred = inferred_.cloneWithRealTypes();
    std::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
    if (schema_difference.has_value()) {
      TORCH_CHECK(false,
        "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
        " operator: ", toString(name), "\n",
        " expected schema: ", toString(from_def), "\n",
        " ", from_def_debug, "\n",
        " inferred schema: ", toString(inferred), "\n",
        " ", inferred_debug, "\n",
        " reason: ", *schema_difference);
    }
  }
} // anonymous namespace

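// A minimal sketch of the mismatch checkSchema() rejects (hypothetical op and
// kernel; the names are not from this codebase):
//
//   TORCH_LIBRARY(myns, m) {
//     m.def("my_op(Tensor self, int n) -> Tensor");
//   }
//   // Kernel takes double where the schema says int, so the inferred schema
//   // "(Tensor _0, float _1) -> Tensor" differs and registration fails with
//   // the "reason:" diff from findSchemaDifferences().
//   at::Tensor my_op_kernel(const at::Tensor& self, double n);
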
const AnnotatedKernel& OperatorEntry::missingKernel() const {
  static AnnotatedKernel kernel;
  return kernel;
}

const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
  static AnnotatedKernel kernel(
    c10::KernelFunction::makeAmbiguousAutogradOther(), nullptr, "ambiguous_autogradother");
  return kernel;
}

void OperatorEntry::assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const {
  if (has_symint) {
    if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
      reportSignatureError(call_signature, *sym_cpp_signature_);
    }
  } else {
    if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
      reportSignatureError(call_signature, *cpp_signature_);
    }
  }
}

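// An illustrative caller of this check (hypothetical operator name): a typed
// handle obtained via
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("myns::my_op", "")
//       .typed<at::Tensor(const at::Tensor&)>();
//
// asserts once that Tensor(const Tensor&) matches the CppSignature recorded at
// kernel registration; on mismatch, reportSignatureError() raises.
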
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
  TORCH_INTERNAL_ASSERT(!schema_.has_value());
  for (const auto& kernel : kernels_) {
    for (const auto& j : kernel.second) {
      if (j.inferred_function_schema != nullptr) {
        checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
      }
    }
  }
  // NB: don't register schema until after we've checked everything!
  dispatchKeyExtractor_.registerSchema(schema);
  schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
#ifndef C10_MOBILE
  tags_ = std::move(tags);
#endif
}

void OperatorEntry::deregisterSchema() {
  TORCH_INTERNAL_ASSERT(schema_.has_value());
  schema_ = std::nullopt;
  dispatchKeyExtractor_.deregisterSchema();
}

OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
  const c10::Dispatcher& dispatcher,
  std::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  std::optional<CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  // NB: cpp_signature doesn't get cleared even after the kernel that populated
  // it is deleted. This means you could poison the value of cpp_signature_
  // with a bad signature value, and then it would permanently stay there until
  // you deregister the schema. This can't really be fixed, because we
  // only do a typed() test once in the lifetime of a TypedOperatorHandle,
  // which means if you could validly change the type of a cpp_signature, then
  // that would also invalidate the old TypedOperatorHandles.
  if (cpp_signature.has_value()) {
    auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_;
    if (local_cpp_signature.has_value()) {
      TORCH_CHECK(*cpp_signature == local_cpp_signature->signature,
        "\nMismatch in kernel C++ signatures\n",
        " operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
        " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
        " kernel 1: ", local_cpp_signature->signature.name(), "\n",
        " dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n",
        " ", local_cpp_signature->debug, "\n",
        " kernel 2: ", cpp_signature->name(), "\n",
        " dispatch key: ", toString(dispatch_key), "\n",
        " ", debug, "\n"
      );
    } else {
      local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
    }
  }

  if (schema_ && inferred_function_schema) {
    checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug);
  }

  // Add the kernel to the kernels list,
  // possibly creating the list if this is the first kernel.
  // Redirect catchAll registrations to CompositeImplicitAutograd.
  auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];

#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  if (k[0].kernel.isValid()) {
#else
  if (!k.empty()) {
#endif
    // Suppress the warning for Meta key as we are overriding C++ meta functions with python meta functions
    // for some ops
    if (dispatch_key != DispatchKey::Meta) {
      TORCH_WARN_ONCE("Warning only once for all operators, other operators may also be overridden.\n",
        " Overriding a previously registered kernel for the same operator and the same dispatch key\n",
        " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
        " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
        " dispatch key: ", toString(dispatch_key), "\n",
        " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n",
        " new kernel: ", debug
      );
    }
  }

#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  k[0].kernel = std::move(kernel);
  k[0].inferred_function_schema = std::move(inferred_function_schema);
  k[0].debug = std::move(debug);
#else
  k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
#endif
  AnnotatedKernelContainerIterator inserted = k.begin();
  // update the dispatch table, i.e. re-establish the invariant
  // that the dispatch table points to the newest kernel
  if (dispatch_key.has_value()) {
    updateDispatchTable_(dispatcher, *dispatch_key);
  } else {
    updateDispatchTableFull_(dispatcher);
  }
  return inserted;
}

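// A registration pattern that triggers the TORCH_WARN_ONCE above (hypothetical
// op and kernels):
//
//   TORCH_LIBRARY_IMPL(myns, CPU, m) { m.impl("my_op", kernel_a); }
//   TORCH_LIBRARY_IMPL(myns, CPU, m) { m.impl("my_op", kernel_b); }
//
// Both land in kernels_[DispatchKey::CPU]; kernel_b is pushed to the front and
// becomes the active entry, while kernel_a stays in the list and is reinstated
// if kernel_b is deregistered.
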
void OperatorEntry::deregisterKernel_(
  const c10::Dispatcher& dispatcher,
  std::optional<DispatchKey> dispatch_key,
  AnnotatedKernelContainerIterator kernel
) {
  // Redirect catchAll deregistrations to CompositeImplicitAutograd.
  DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::CompositeImplicitAutograd;
  auto found = kernels_.find(dk);
  TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
  auto& k = found->second;
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  // We are about to remove the array from the map, no need to do anything.
#else
  k.erase(kernel);
#endif
  if (k.empty()) {
    // the invariant says we don't want empty lists but instead remove the list from the map
    kernels_.erase(found);
  }
  updateDispatchTable_(dispatcher, dk);
}

void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
  updateDispatchTable_(dispatcher, dispatch_key);
}

const KernelFunction& OperatorEntry::computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
  return computeDispatchTableEntryWithDebug(dispatcher, dispatch_key).first.kernel;
}

bool OperatorEntry::hasKernelForAnyDispatchKey(DispatchKeySet ks) const {
  TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
  for (auto& kv : kernels_) {
    // Note [No Alias Keys in DispatchKeySet]
    if (!isAliasDispatchKey(kv.first) && ks.has(kv.first)) return true;
  }
  return false;
}

bool OperatorEntry::hasKernelForDispatchKey(DispatchKey k) const {
  TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
  auto it = kernels_.find(k);
  if (it == kernels_.end()) return false;
  return !it->second.empty();
}

const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
  auto it = kernels_.find(k);
  TORCH_CHECK(it != kernels_.end() && !it->second.empty(), "no kernel for ", k, " on ", name_);
  auto jt = it->second.begin();
  TORCH_INTERNAL_ASSERT(jt->kernel.isValid());
  return jt->kernel;
}

bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
  TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
  TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
  return dispatchTable_[dispatch_ix].isValid();
}

const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const {
  auto kern_it = kernels_.find(dispatch_key);
  if (kern_it != kernels_.end()) {
    TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
    TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
    return &kern_it->second.front();
  }
  return nullptr;
}

const std::vector<at::Tag>& OperatorEntry::getTags() const {
#if defined C10_MOBILE
  TORCH_CHECK(false, "tags are not saved for Mobile");
#else
  return tags_;
#endif
}

std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
  // [Note] DispatchTable computation
  // dispatchTable contains entries for runtime dispatch keys.
  // For any dispatch key, it'll pick a kernel using the following order:
  //  (1) Use kernel if it's directly registered to this key
  //  (2) Handle runtime keys that have kernels available from alias keys
  //    (2.1) Use kernel from DispatchKey::CompositeExplicitAutogradNonFunctional if available.
  //          This is used to register a kernel that works for all backends in inference, except "functional" backends
  //          like LazyTensor/XLA. But it requires separate registration for Autograd keys to support training.
  //    (2.2) Use kernel from DispatchKey::CompositeExplicitAutograd if available.
  //          This is used to register a kernel that works for all backends in inference. But it requires
  //          separate registration for Autograd keys to support training.
  //    (2.3) Use kernel from DispatchKey::CompositeImplicitAutograd if available.
  //          For autograd keys, we only use the kernel from CompositeImplicitAutograd when there's no direct registration
  //          to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
  //          For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's a registration to any of
  //          its backends and ask the backend extender to request a dedicated Autograd key for the backend.
  //          See Note [Ambiguity in AutogradOther kernel] for more details.
  //          A CompositeExplicitAutograd kernel prevents the CompositeImplicitAutograd kernel from being used for Autograd keys, but it doesn't
  //          cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
  //          in this case.
  //    (2.4) Use kernel from DispatchKey::Autograd if available
  //    (2.5) Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
  //    The implementation of (2.2) relies on the invariant that for a given backend,
  //    `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the
  //    backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_]
  //  (3) Use fallthrough kernels that are registered as fallbacks.
  // Alias Key Precedence:
  //   CompositeExplicitAutogradNonFunctional > CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd
  // Note [CompositeExplicitAutograd and CompositeImplicitAutograd]
  //   When there are registrations to all of CompositeExplicitAutograd, CompositeImplicitAutograd and Autograd, from (2.2) we know the CompositeExplicitAutograd
  //   and Autograd kernels will be picked up and CompositeImplicitAutograd is overridden.
  //   This is fine and in practice CompositeExplicitAutograd and CompositeImplicitAutograd shouldn't co-exist for an op.
  // TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA .

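  // Worked example (illustrative, for a hypothetical op): with kernels
  // registered only to CPU and CompositeImplicitAutograd, the rules above
  // resolve to:
  //   CPU         -> CPU kernel                        (case 1)
  //   CUDA        -> CompositeImplicitAutograd kernel  (case 2.3)
  //   AutogradCPU -> not the CompositeImplicitAutograd kernel, since CPU has
  //                  a direct backend registration (has_backend_kernel below
  //                  is true); resolution continues to (2.4) and then (3).
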
  // 1. Operator registration
  if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
    return {*direct_registration, "kernel"};
  }

  // 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available.
  //     See Note [Undefined in dispatchTable_] for the special handling for Undefined.
  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutogradNonFunctional)) {
    if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutogradNonFunctional)) {
      return {*default_backend_registration, "default backend kernel"};
    }
  }

  // 2.2 Use CompositeExplicitAutograd kernel if available.
  //     See Note [Undefined in dispatchTable_] for the special handling for Undefined.
  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutograd)) {
    if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd)) {
      return {*default_backend_registration, "default backend kernel"};
    }
  }

  // Note when there's a direct registration to CompositeExplicitAutograd, this code path will only be hit by
  // non-backend keys (e.g. AutogradXXX, Batched etc) due to (2.1).
  bool has_backend_kernel =
    hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key)) ||
    // See Note [No Alias Keys in DispatchKeySet]
    hasKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd);

  // 2.3. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use the kernel from CompositeImplicitAutograd
  //      when there's no direct registration to its corresponding backend key or CompositeExplicitAutograd.
  //      For AutogradOther, we return ambiguousAutogradOtherKernel() if there's a registration
  //      to any of its backends.
  //      See Note [Undefined in dispatchTable_] for the special handling for Undefined.

  // If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
  // then we register it to the nested-tensor kernel rather than the
  // regular-tensor CompositeImplicitAutograd kernel.
  // We have no intention to change the behavior of Undefined,
  // so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
  // to let the original CompositeImplicitAutograd handle Undefined.
  // See Note: [Disjoint AliasKeyset] The order for this alias key doesn't matter
  if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
    if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
      return {*nested_registration, "nested kernel"};
    }
  }

  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
    if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
      if (dispatch_key == DispatchKey::AutogradOther
          && hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
        return {ambiguousAutogradOtherKernel(), "ambiguous autogradother"};
      } else if (!has_backend_kernel) {
        return {*math_registration, "math kernel"};
      }
    }
  }

  // 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
  if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)) {
    if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) {
      return {*autograd_registration, "autograd kernel"};
    }
  }

  // 2.5. For batched backend keys, use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
  //      See Note: [Disjoint AliasKeyset] The order for this alias key doesn't matter
  if (isIncludedInAlias(dispatch_key, DispatchKey::FuncTorchBatchedDecomposition)) {
    if (auto batched_registration = getKernelForDispatchKey(DispatchKey::FuncTorchBatchedDecomposition)) {
      return {*batched_registration, "batched kernel"};
    }
  }

  // 3. Backend fallback
  auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
  if (dispatch_ix < 0) {
    return {missingKernel(), "backend fallback not registered on mobile"};
  }
  if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
    return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
  }

  // 4. Default to error
  return {missingKernel(), "missing"};
}

// synchronizes the dispatch table entry for a given dispatch key
// with the current state of kernel registrations in the dispatcher.
// note that this is not a complete update, due to relationships between
// dispatch keys (e.g. runtime keys and their associated autograd keys,
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
  if (C10_UNLIKELY(dispatch_ix == -1)) {
    return;
  }
  dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
  dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

// synchronizes the dispatch table entries for a given dispatch key *and its
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// function will synchronize the dispatcher state. See e.g. registerKernel()
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
  // Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
  // See Note [Undefined in dispatchTable_]
  if (dispatch_key == DispatchKey::Undefined) {
    updateDispatchTableEntry_(dispatcher, dispatch_key);
    return;
  }
  for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
    updateDispatchTableEntry_(dispatcher, k);
  }
  // Registrations to CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined.
  // We cannot do this above since Undefined cannot be represented in DispatchKeySet.
  if (dispatch_key == DispatchKey::CompositeImplicitAutograd
      || dispatch_key == DispatchKey::CompositeExplicitAutograd
      || dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
    updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
  }
  // Note [Refresh Runtime Autograd entries in dispatchTable_]
  // Registering to a backend key might affect the computed entry at its Autograd backend key due to (2.1) & (2.3).
  // In theory, we should only have to check if the given runtime key has "dense" functionality,
  // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
  // However, there are some backends that should be included in this set that don't have the dense key set.
  // E.g. DispatchKey::Meta, DispatchKey::MAIA.
  if (c10::isBackendDispatchKey(dispatch_key)) {
    DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
    updateDispatchTableEntry_(dispatcher, autograd_key);
  }
}

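// An illustrative consequence of the refresh above (hypothetical op):
// registering a CPU kernel triggers updateDispatchTable_(dispatcher,
// DispatchKey::CPU), which also recomputes the AutogradCPU entry, because an
// existing CompositeImplicitAutograd kernel stops applying to AutogradCPU once
// a direct CPU registration exists (see (2.3) above).
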
// does a complete update of the dispatch table, synchronizing all
// runtime dispatch keys with the current state of kernel registrations
// in the dispatcher.
// Note that we use updateDispatchTable_() to perform our per-key updating,
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate: the
// current design is more tractable with all updates funneled through a single
// per-key update mechanism, than with multiple variations that assume different
// invariants.
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
  // Note [Undefined in dispatchTable_]
  // DispatchKey Undefined is used in runtime:
  // (1) it gives people a place to specify functionality that should run when there are no dispatch keys,
  //     e.g., an op without Tensor inputs or empty TensorList arguments
  // (2) it lets us remove the explicit error checking code from the dispatch hotpath: when no dispatch keys
  //     are available we simply slide into the Undefined handler, which then raises the error message.
  // In the old world of catchAll, the only way to "register" a kernel to Undefined was by registering it to
  // catchAll. Now that catchAllKernel_ has been removed, Undefined instead picks up a kernel from either the
  // CompositeExplicitAutograd or CompositeImplicitAutograd alias key so that we don't break that support.
  // Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd) should return true; it returns false
  // because Undefined cannot be represented in a DispatchKeySet.
  updateDispatchTable_(dispatcher, DispatchKey::Undefined);
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
    updateDispatchTable_(dispatcher, k);
  }
}

void OperatorEntry::checkInvariants() const {
  if (schema_) {
    TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState());
    dispatchKeyExtractor().checkInvariants(schema_->schema);
  }
  TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end(), dumpState());
  for (const auto& kv : kernels_) {
    TORCH_INTERNAL_ASSERT(!kv.second.empty(), dumpState());
  }
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
    auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k);
    auto idx = getDispatchTableIndexForDispatchKey(k);
    if (C10_UNLIKELY(idx == -1)) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]),
      "Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
      "Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
  }
}

std::string OperatorEntry::listAllDispatchKeys() const {
  std::ostringstream str;
  str << "[";

  bool has_kernels = false;
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
    auto iter = getDispatchTableIndexForDispatchKey(k);
    if (iter == -1 || !dispatchTable_[iter].isValid()) {
      continue;
    }
    if (has_kernels) {
      str << ", ";
    }
    str << k;
    has_kernels = true;
  }
  str << "]";
  return str.str();
}

void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const {
  TORCH_CHECK(false,
    "\nTried to access or call an operator with a wrong signature.\n",
    " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
    " ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n",
    " correct signature: ", saved_signature.signature.name(), "\n",
    " ", saved_signature.debug, "\n",
    " accessed/called as: ", call_signature.name(), "\n",
    "This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
    "Please make sure that the function signature matches the signature in the operator registration call."
  );
}

#ifndef STRIP_ERROR_MESSAGES
static std::string post_process_dispatch_key_str(std::string dispatch_key) {
  const std::string substr = "PrivateUse1";
  if (substr.size() <= dispatch_key.size() && std::equal(substr.rbegin(), substr.rend(), dispatch_key.rbegin())) {
    const auto privateuse1_backend = get_privateuse1_backend();
    if (privateuse1_backend != "privateuseone") {
      // remove the trailing "PrivateUse1" and append the registered backend's
      // name, e.g. AutogradPrivateUse1 -> AutogradFoo
      dispatch_key.erase(dispatch_key.length() - substr.length());
      dispatch_key = dispatch_key + privateuse1_backend;
    }
  }
  return dispatch_key;
}
#endif

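// For instance (assuming a backend was renamed via
// c10::register_privateuse1_backend("Foo")), "AutogradPrivateUse1" is rendered
// as "AutogradFoo" in the error message below.
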
void OperatorEntry::reportError(DispatchKey dispatchKey) const {
  // If there is an invariant problem, report it now.
  checkInvariants();

  if (report_error_callback_ != nullptr) {
    report_error_callback_->pyinterpreter()->reportErrorCallback(report_error_callback_->ptr(&report_error_callback_->pyinterpreter()), dispatchKey);
    // reportErrorCallback should have raised an error
    TORCH_INTERNAL_ASSERT(false);
  }
  if (dispatchKey == DispatchKey::Undefined) {
    TORCH_CHECK_NOT_IMPLEMENTED(false,
      "There were no tensor arguments to this function (e.g., you passed an "
      "empty list of Tensors), but no fallback function is registered for schema ", name_,
      ". This usually means that this function requires a non-empty list of Tensors, "
      "or that you (the operator writer) forgot to register a fallback function. "
      "Available functions are ", listAllDispatchKeys(), ".\n\n", dumpComputedTable());
  }

  TORCH_CHECK_NOT_IMPLEMENTED(false, "Could not run '", name_, "' with arguments",
    " from the '", post_process_dispatch_key_str(toString(dispatchKey)), "' backend. This could be because "
    "the operator doesn't exist for this backend, or was omitted during ",
    "the selective/custom build process (if using custom build). If you are a ",
    "Facebook employee using PyTorch on mobile, please visit ",
    "https://fburl.com/ptmfixes for possible resolutions. '",
    name_, "' is only available for these backends: ",
    listAllDispatchKeys(), ".\n\n", dumpComputedTable());
}

// INSPECTING DISPATCHER STATE
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The dumper functions purposely do not check invariants, as you might be using
// them to debug situations where the invariants are violated.

// Inspect what the computed dispatch table would be (e.g., what
// updateDispatchTableFull_ would update the dispatch table to be)
std::string OperatorEntry::dumpComputedTable() const {
  std::ostringstream oss;
  // Need to handle Undefined separately, because it's a runtime key that can't be represented
  // in a DispatchKeySet.
  std::vector<DispatchKey> runtime_keys = {DispatchKey::Undefined};
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k);

  for (auto k : runtime_keys) {
    auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
    if (kernel_prov.first.kernel.isValid()) {
      oss << toString(k) << ": "
          << (kernel_prov.first.kernel.isFallthrough() ? "fallthrough " : "")
          << kernel_prov.first.debug << " [" << kernel_prov.second << "]\n";
    }
  }
  return oss.str();
}

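// Typical dumpComputedTable() output (illustrative, for a hypothetical op with
// a CPU kernel and an autograd fallthrough fallback; debug strings are
// placeholders):
//
//   CPU: <debug string of the CPU kernel> [kernel]
//   AutogradCPU: fallthrough <debug string of the fallback> [backend fallback]
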
void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
  report_error_callback_ = std::move(callback);
}

// Inspect the "canonical" information in OperatorEntry. This only prints out
// *non-derived* information including kernels registered to alias dispatch keys;
// i.e., what the source of truth says about the operator. This dumping function
// is appropriate for expect tests.
// This WON'T report backend fallbacks.
std::string OperatorEntry::dumpState() const {
  std::ostringstream oss;
  oss << "name: " << name_ << "\n";
  if (schema_) {
    oss << "schema: " << schema_->schema << "\n";
    oss << "debug: " << schema_->debug << "\n";
    oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis())
        << (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
  } else {
    oss << "schema: (none)\n";
  }

  auto print_kernel = [&](const char* k_desc, const AnnotatedKernelContainer& jts, bool is_alias_key = false) {
    int64_t i = 0;
    for (const auto& jt : jts) {
      oss << k_desc
          << (is_alias_key ? "[alias]" : "")
          << (i > 0 ? " (inactive)" : "")
          << ": "
          << jt.debug << " :: "
          << (jt.inferred_function_schema ? toString(*jt.inferred_function_schema) : "(none)")
          << " [ " << jt.kernel.dumpState() << "]\n";
      i++;
    }
  };

  // Iterate over DispatchKey, not the flat hash map, so we have a stable order
  for (uint8_t i = 0; i <= static_cast<uint8_t>(DispatchKey::EndOfAliasKeys); i++) {
    auto k = static_cast<DispatchKey>(i);
    auto it = kernels_.find(k);
    if (it != kernels_.end()) {
      print_kernel(toString(k), it->second, c10::isAliasDispatchKey(k));
    }
  }
  return oss.str();
}

} // namespace impl
} // namespace c10