xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/dispatch/OperatorEntry.h>
2 #include <ATen/core/op_registration/infer_schema.h>
3 #include <ATen/core/dispatch/Dispatcher.h>
4 #include <ATen/core/dispatch/ObservedOperators.h>
5 
6 namespace c10 {
7 namespace impl {
8 
9 namespace {
10 #ifndef STRIP_ERROR_MESSAGES
toString(std::optional<DispatchKey> k)11   std::string toString(std::optional<DispatchKey> k) {
12     if (k.has_value()) {
13       return toString(*k);
14     } else {
15       return "(catch all)";
16     }
17   }
18 #endif
19 }
20 
OperatorEntry(OperatorName && operator_name)21 OperatorEntry::OperatorEntry(OperatorName&& operator_name)
22 : name_(std::move(operator_name))
23 , schema_()
24 #ifndef C10_MOBILE
25 , tags_()
26 #endif
27 , dispatchTable_()
28 , dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
29 , kernels_()
30 , cpp_signature_()
31 , sym_cpp_signature_()
32 , is_observed_(ObservedOperators::isObserved(name_))
33 {
34   // Pick up any backend fallbacks that were registered prior to this
35   // OperatorEntry being created
36   updateDispatchTableFull_(c10::Dispatcher::singleton());
37 }
38 
39 namespace {
checkSchema(const OperatorName & name,const FunctionSchema & from_def_,const std::string & from_def_debug,const KernelFunction & kernel,const FunctionSchema & inferred_,const std::string & inferred_debug)40   void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) {
41     // TODO: figure out if we can just directly save real schema at def time
42     FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed());
43     FunctionSchema inferred = inferred_.cloneWithRealTypes();
44     std::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
45     if (schema_difference.has_value()) {
46       TORCH_CHECK(false,
47         "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
48         "  operator: ", toString(name), "\n",
49         "  expected schema: ", toString(from_def), "\n",
50         "    ", from_def_debug, "\n",
51         "  inferred schema: ", toString(inferred), "\n",
52         "    ", inferred_debug, "\n",
53         "  reason: ", *schema_difference);
54     }
55   }
56 } // anonymous namespace
57 
missingKernel() const58 const AnnotatedKernel& OperatorEntry::missingKernel() const {
59   static AnnotatedKernel kernel;
60   return kernel;
61 }
62 
ambiguousAutogradOtherKernel() const63 const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
64   static AnnotatedKernel kernel(
65     c10::KernelFunction::makeAmbiguousAutogradOther(), nullptr, "ambiguous_autogradother");
66   return kernel;
67 }
68 
assertSignatureIsCorrect(const CppSignature & call_signature,bool has_symint) const69 void OperatorEntry::assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const {
70   if (has_symint) {
71     if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
72       reportSignatureError(call_signature, *sym_cpp_signature_);
73     }
74   } else {
75     if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
76       reportSignatureError(call_signature, *cpp_signature_);
77     }
78   }
79 }
80 
registerSchema(FunctionSchema && schema,std::string && debug,std::vector<at::Tag> tags)81 void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
82   TORCH_INTERNAL_ASSERT(!schema_.has_value());
83   for (const auto& kernel : kernels_) {
84     for (const auto &j : kernel.second) {
85       if (j.inferred_function_schema != nullptr) {
86         checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
87       }
88     }
89   }
90   // NB: don't register schema until after we've checked everything!
91   dispatchKeyExtractor_.registerSchema(schema);
92   schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
93   #ifndef C10_MOBILE
94     tags_ = std::move(tags);
95   #endif
96 }
97 
deregisterSchema()98 void OperatorEntry::deregisterSchema() {
99   TORCH_INTERNAL_ASSERT(schema_.has_value());
100   schema_ = std::nullopt;
101   dispatchKeyExtractor_.deregisterSchema();
102 }
103 
registerKernel(const c10::Dispatcher & dispatcher,std::optional<DispatchKey> dispatch_key,KernelFunction kernel,std::optional<CppSignature> cpp_signature,std::unique_ptr<FunctionSchema> inferred_function_schema,std::string debug)104 OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
105   const c10::Dispatcher& dispatcher,
106   std::optional<DispatchKey> dispatch_key,
107   KernelFunction kernel,
108   std::optional<CppSignature> cpp_signature,
109   std::unique_ptr<FunctionSchema> inferred_function_schema,
110   std::string debug
111 ) {
112   // NB: cpp_signature doesn't get cleared even after the kernel that populated
113   // it is deleted.  This means you could poison the value of cpp_signature_
114   // with a bad signature value, and then it would permanently stay there until
115   // you deregister the schema.  This can't really be fixed, because we
116   // only do a typed() test once in the lifetime of a TypedOperatorHandle,
117   // which means if you could validly change the type of a cpp_signature, then
118   // that would also invalidate the old TypedOperatorHandles.
119   if (cpp_signature.has_value()) {
120     auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_;
121     if (local_cpp_signature.has_value()) {
122       TORCH_CHECK(*cpp_signature == local_cpp_signature->signature,
123         "\nMismatch in kernel C++ signatures\n",
124         "  operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
125         "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
126         "  kernel 1: ", local_cpp_signature->signature.name(), "\n",
127         "    dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n",
128         "    ", local_cpp_signature->debug, "\n",
129         "  kernel 2: ", cpp_signature->name(), "\n",
130         "    dispatch key: ", toString(dispatch_key), "\n",
131         "    ", debug, "\n"
132       );
133     } else {
134       local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
135     }
136   }
137 
138   if (schema_ && inferred_function_schema) {
139     checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug);
140   }
141 
142   // Add the kernel to the kernels list,
143   // possibly creating the list if this is the first kernel.
144   // Redirect catchAll registrations to CompositeImplicitAutograd.
145   auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
146 
147 #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
148   if (k[0].kernel.isValid()) {
149 #else
150   if (!k.empty()) {
151 #endif
152     // Suppress the warning for Meta key as we are overriding C++ meta functions with python meta functions
153     // for some ops
154     if (dispatch_key != DispatchKey::Meta) {
155       TORCH_WARN_ONCE("Warning only once for all operators,  other operators may also be overridden.\n",
156             "  Overriding a previously registered kernel for the same operator and the same dispatch key\n",
157             "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
158             "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
159             "  dispatch key: ", toString(dispatch_key), "\n",
160             "  previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n",
161             "       new kernel: ", debug
162       );
163     }
164   }
165 
166 #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
167   k[0].kernel = std::move(kernel);
168   k[0].inferred_function_schema = std::move(inferred_function_schema);
169   k[0].debug = std::move(debug);
170 #else
171   k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
172 #endif
173   AnnotatedKernelContainerIterator inserted = k.begin();
174   // update the dispatch table, i.e. re-establish the invariant
175   // that the dispatch table points to the newest kernel
176   if (dispatch_key.has_value()) {
177     updateDispatchTable_(dispatcher, *dispatch_key);
178   } else {
179     updateDispatchTableFull_(dispatcher);
180   }
181   return inserted;
182 }
183 
184 void OperatorEntry::deregisterKernel_(
185   const c10::Dispatcher& dispatcher,
186   std::optional<DispatchKey> dispatch_key,
187   AnnotatedKernelContainerIterator kernel
188 ) {
189   // Redirect catchAll deregistrations to CompositeImplicitAutograd.
190   DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::CompositeImplicitAutograd;
191   auto found = kernels_.find(dk);
192   TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
193   auto& k = found->second;
194 #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
195   // We are about to remove the array from the map, no need to do anything.
196 #else
197   k.erase(kernel);
198 #endif
199   if (k.empty()) {
200     // the invariant says we don't want empty lists but instead remove the list from the map
201     kernels_.erase(found);
202   }
203   updateDispatchTable_(dispatcher, dk);
204 }
205 
206 void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
207   updateDispatchTable_(dispatcher, dispatch_key);
208 }
209 
210 const KernelFunction& OperatorEntry::computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
211   return computeDispatchTableEntryWithDebug(dispatcher, dispatch_key).first.kernel;
212 }
213 
214 bool OperatorEntry::hasKernelForAnyDispatchKey(DispatchKeySet ks) const {
215   TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
216   for (auto& kv : kernels_) {
217     // Note [No Alias Keys in DispatchKeySet]
218     if (!isAliasDispatchKey(kv.first) && ks.has(kv.first)) return true;
219   }
220   return false;
221 }
222 
223 bool OperatorEntry::hasKernelForDispatchKey(DispatchKey k) const {
224   TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
225   auto it = kernels_.find(k);
226   if (it == kernels_.end()) return false;
227   return !it->second.empty();
228 }
229 
230 const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
231   auto it = kernels_.find(k);
232   TORCH_CHECK(it != kernels_.end() && !it->second.empty(), "no kernel for ", k, " on ", name_);
233   auto jt = it->second.begin();
234   TORCH_INTERNAL_ASSERT(jt->kernel.isValid())
235   return jt->kernel;
236 }
237 
238 bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
239   TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
240   const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
241   TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
242   return dispatchTable_[dispatch_ix].isValid();
243 }
244 
245 const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
246   auto kern_it = kernels_.find(dispatch_key);
247   if (kern_it != kernels_.end()) {
248     TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
249     TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
250     return &kern_it->second.front();
251   }
252   return nullptr;
253 }
254 
255 const std::vector<at::Tag>& OperatorEntry::getTags() const {
256   #if defined C10_MOBILE
257     TORCH_CHECK(false, "tags are not saved for Mobile");
258   #else
259     return tags_;
260   #endif
261 }
262 
263 std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
264   // [Note] DispatchTable computation
265   // dispatchTable contains entries for runtime dispatch keys.
266   // For any dispatch key, it'll pick a kernel using the following order:
267   //  (1) Use kernel if it's directly registered to this key
268   //  (2) Handle runtime keys that have kernels available from alias keys
269   //    (2.1) Use kernel from DispatchKey::CompositeExplicitAutogradNonFunctional if available.
270   //          This is used to register a kernel that works for all backends in inference, except "functional" backends
271   //          like LazyTensor/XLA. But it requires separate registration for Autograd keys to support training.
272   //    (2.2) Use kernel from DispatchKey::CompositeExplicitAutograd if available.
273   //          This is used to register a kernel that works for all backend in inference. But it requires
274   //          separate registration for Autograd keys to support training.
275   //    (2.3) Use kernel from DispatchKey::CompositeImplicitAutograd if available.
276   //          For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
277   //          to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
278   //          For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
279   //          its backends and ask backend extender to request a decicated Autograd key for the backend.
280   //          See Note [Ambiguity in AutogradOther kernel] for more details.
281   //          A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
282   //          cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
283   //          in this case.
284   //    (2.4) Use kernel from DispatchKey::Autograd if available
285   //    (2.5) Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
286   //    The implementation of (2.2) relies on the invariant that for a given backend,
287   //    `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the
288   //    backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_]
289   //  (3) Use fallthrough kernel that are registered as fallback.
290   // Alias Key Precedence:
291   //   CompositExplicitAutogradNonFunctional > CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd
292   // Note [CompositeExplicitAutograd and CompositeImplicitAutograd]
293   //   When there're registrations to both CompositeExplicitAutograd & CompositeImplicitAutograd & Autograd, from (2.2) we know CompositeExplicitAutograd
294   //   and Autograd kernels will be picked up and CompositeImplicitAutograd is overriden.
295   //   This is fine and in practice CompositeExplicitAutograd and CompositeImplicitAutograd shouldn't co-exist for an op.
296   // TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA .
297 
298   // 1. Operator registration
299   if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
300     return {*direct_registration, "kernel"};
301   }
302 
303   // 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available.
304   //     See Note [Undefined in dispatchTable_] for the special handling for Undefined.
305   if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutogradNonFunctional)) {
306     if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutogradNonFunctional)) {
307       return {*default_backend_registration, "default backend kernel"};
308     }
309   }
310 
311   // 2.2 Use CompositeExplicitAutograd kernel if available.
312   //     See Note [Undefined in dispatchTable_] for the special handling for Undefined.
313   if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutograd)) {
314     if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd)) {
315       return {*default_backend_registration, "default backend kernel"};
316     }
317   }
318 
319   // Note when there's direct registration to CompositeExplicitAutograd, this code path will only be hit by
320   // non backend keys (e.g AutogradXXX, Batched etc) due to (2.1).
321   bool has_backend_kernel =
322     hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key)) ||
323     // See Note [No Alias Keys in DispatchKeySet]
324     hasKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd);
325 
326   // 2.3. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use kernel from CompositeImplicitAutograd
327   //      when there's no direct registration to its corresponding backend key or CompositeExplicitAutograd.
328   //      For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
329   //      to any of its backends.
330   //      See Note [Undefined in dispatchTable_] for the special handling for Undefined.
331 
332   // If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
333   // then we register it to nested-tensor kernel rather than
334   // regular-tensor CompositeImplicitAutograd kernel.
335   // We have no intention to change the behavior of Undefined,
336   // so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
337   // to let the original CompositeImplicitAutograd handle Undefined
338   // See Note: [Disjoint AliasKeyset] The order for this alias key doesn't matter
339   if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
340     if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
341       return {*nested_registration, "nested kernel"};
342       }
343   }
344 
345   if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
346     if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
347       if (dispatch_key == DispatchKey::AutogradOther
348           && hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
349         return {ambiguousAutogradOtherKernel(), "ambiguous autogradother"};
350       } else if (!has_backend_kernel) {
351         return {*math_registration, "math kernel"};
352       }
353     }
354   }
355 
356   // 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
357   if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)) {
358     if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) {
359       return {*autograd_registration, "autograd kernel"};
360     }
361   }
362 
363   // 2.5. For batched backend keys, use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
364   // See Note: [Disjoint AliasKeyset] The order for this alias key doesn't matter
365   if (isIncludedInAlias(dispatch_key, DispatchKey::FuncTorchBatchedDecomposition)) {
366     if (auto batched_registration = getKernelForDispatchKey(DispatchKey::FuncTorchBatchedDecomposition)) {
367       return {*batched_registration, "batched kernel"};
368     }
369   }
370 
371   // 3. Backend fallback
372   auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
373   if (dispatch_ix < 0) {
374     return {missingKernel(), "backend fallback not registered on mobile"};
375   }
376   if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
377     return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
378   }
379 
380   // 4. Default to error
381   return {missingKernel(), "missing"};
382 }
383 
384 // synchronizes the dispatch table entry for a given dispatch key
385 // with the current state of kernel registrations in the dispatcher.
386 // note that this is not a complete update, due to relationships between
387 // dispatch keys (e.g. runtime keys and their associated autograd keys,
388 // or alias keys and their associated keysets).
389 // This function should be considered a private helper for updateDispatchTable_()
390 void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
391   const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
392   if (C10_UNLIKELY(dispatch_ix == -1)) {
393     return;
394   }
395   dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
396   dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
397 }
398 
399 // synchronizes the dispatch table entries for a given dispatch key *and its
400 // associated keys* with the current state of kernel registrations in the
401 // dispatcher.
402 // After a kernel has been registered to a dispatch key, a call to this
403 // function will synchronize the dispatcher state. See e.g. registerKernel()
404 void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
405   // Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
406   // See Note [Undefined in dispatchTable_]
407   if (dispatch_key == DispatchKey::Undefined) {
408     updateDispatchTableEntry_(dispatcher, dispatch_key);
409     return;
410   }
411   for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
412     updateDispatchTableEntry_(dispatcher, k);
413   }
414   // Registration to CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined.
415   // We cannot do this above since Undefined cannot be represented in DispatchKeySet.
416   if (dispatch_key == DispatchKey::CompositeImplicitAutograd
417    || dispatch_key == DispatchKey::CompositeExplicitAutograd
418    || dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
419     updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
420   }
421   // Note [Refresh Runtime Autograd entries in dispatchTable_]
422   // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
423   // In theory, we should only have to check if the given runtime key has "dense" functionality,
424   // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
425   // However, there are some backends that should be included in this set that don't have the dense key set.
426   // E.g. DispatchKey::Meta, DispatchKey::MAIA.
427   if (c10::isBackendDispatchKey(dispatch_key)) {
428     DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
429     updateDispatchTableEntry_(dispatcher, autograd_key);
430   }
431 }
432 
433 // does a complete update of the dispatch table, synchronizing all
434 // runtime dispatch keys with the current state of kernel registrations
435 // in the dispatcher.
436 // Note that we use updateDispatchTable_() to perform our per-key updating,
437 // even though that function is equipped to handle out-of-order updates and
438 // alias key updates, neither of which we send it. This is deliberate - the
439 // current design is more tractable with all updates funneled through a single
440 // per-key update mechanism, than with multiple variations that assume different
441 // invariants.
442 //
443 void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
444   // Note [Undefined in dispatchTable_]
445   // DispatchKey Undefined is used in runtime:
446   // (1) it gives people place to specify functionality that should run when there are no dispatch keys,
447   //     e.g., an op without Tensor inputs or empty TensorList arguments
448   // (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when
449   //     no dispatch keys are available we just slide into the undefined handler which would then raise
450   //     the error message.
451   // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to
452   // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd,
453   // or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd)
454   // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
455   updateDispatchTable_(dispatcher, DispatchKey::Undefined);
456   for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
457     updateDispatchTable_(dispatcher, k);
458   }
459 }
460 
461 void OperatorEntry::checkInvariants() const {
462   if (schema_) {
463     TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState());
464     dispatchKeyExtractor().checkInvariants(schema_->schema);
465   }
466   TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end(), dumpState());
467   for (const auto& kv : kernels_) {
468     TORCH_INTERNAL_ASSERT(!kv.second.empty(), dumpState());
469   }
470   for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
471     auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k);
472     auto idx = getDispatchTableIndexForDispatchKey(k);
473     if (C10_UNLIKELY(idx == -1)) {
474       continue;
475     }
476     TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]),
477       "Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
478       "Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
479   }
480 }
481 
482 std::string OperatorEntry::listAllDispatchKeys() const {
483   std::ostringstream str;
484   str << "[";
485 
486   bool has_kernels = false;
487   for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
488     auto iter = getDispatchTableIndexForDispatchKey(k);
489     if (iter == -1 || !dispatchTable_[iter].isValid()) {
490       continue;
491     }
492     if (has_kernels) {
493       str << ", ";
494     }
495     str << k;
496     has_kernels = true;
497   }
498   str << "]";
499   return str.str();
500 }
501 
502 void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const {
503   TORCH_CHECK(false,
504         "\nTried to access or call an operator with a wrong signature.\n",
505         "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
506         "    ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n",
507         "  correct signature:  ", saved_signature.signature.name(), "\n",
508         "    ", saved_signature.debug, "\n",
509         "  accessed/called as: ", call_signature.name(), "\n",
510         "This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
511         "Please make sure that the function signature matches the signature in the operator registration call."
512   );
513 };
514 
515 #ifndef STRIP_ERROR_MESSAGES
516 static std::string post_process_dispatch_key_str(std::string dispatch_key) {
517   const std::string substr = "PrivateUse1";
518   if (substr.size() <= dispatch_key.size() && std::equal(substr.rbegin(), substr.rend(), dispatch_key.rbegin())) {
519     auto privateuse1_backend = get_privateuse1_backend();
520     if (privateuse1_backend != "privateuseone") {
521       // remove trailing "*PrivateUse1"
522       dispatch_key.erase(dispatch_key.length() - substr.length());
523       // append the registered backend's name.
524       // AutogradPrivateUse1 -> AutogradFoo
525       auto backend_name = c10::get_privateuse1_backend();
526       dispatch_key = dispatch_key + backend_name;
527     }
528   }
529   return dispatch_key;
530 }
531 #endif
532 
533 void OperatorEntry::reportError(DispatchKey dispatchKey) const {
534   // If there is an invariant problem, report it now.
535   checkInvariants();
536 
537   if (report_error_callback_ != nullptr) {
538     report_error_callback_->pyinterpreter()->reportErrorCallback(report_error_callback_->ptr(&report_error_callback_->pyinterpreter()), dispatchKey);
539     // reportErrorCallback should have raised an error
540     TORCH_INTERNAL_ASSERT(false);
541   }
542   if (dispatchKey == DispatchKey::Undefined) {
543     TORCH_CHECK_NOT_IMPLEMENTED(false,
544           "There were no tensor arguments to this function (e.g., you passed an "
545           "empty list of Tensors), but no fallback function is registered for schema ", name_,
546           ".  This usually means that this function requires a non-empty list of Tensors, "
547           "or that you (the operator writer) forgot to register a fallback function.  "
548           "Available functions are ", listAllDispatchKeys(), ".\n\n", dumpComputedTable())
549   }
550 
551   TORCH_CHECK_NOT_IMPLEMENTED(false, "Could not run '", name_, "' with arguments",
552           " from the '", post_process_dispatch_key_str(toString(dispatchKey)), "' backend. This could be because "
553           "the operator doesn't exist for this backend, or was omitted during ",
554           "the selective/custom build process (if using custom build). If you are a ",
555           "Facebook employee using PyTorch on mobile, please visit ",
556           "https://fburl.com/ptmfixes for possible resolutions. '",
557           name_, "' is only available for these backends: ",
558           listAllDispatchKeys(), ".\n\n", dumpComputedTable());
559 }
560 
561 // INSPECTING DISPATCHER STATE
562 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~
563 // The dumper functions purposely do not check invariants, as you might be using
564 // them to debug situations where the invariants are violated.
565 
566 // Inspect what the computed dispatch table would be (e.g., what
567 // updateDispatchTableFull_ would update the dispatch table to be)
568 std::string OperatorEntry::dumpComputedTable() const {
569   std::ostringstream oss;
570   // Need to handle Undefined separately, because its a runtime key that can't be represented
571   // in a DispatchKeySet.
572   std::vector<DispatchKey> runtime_keys = {DispatchKey::Undefined};
573   for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k);
574 
575   for (auto k : runtime_keys) {
576     auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
577     if (kernel_prov.first.kernel.isValid()) {
578       oss << toString(k) << ": "
579           << (kernel_prov.first.kernel.isFallthrough() ? "fallthrough " : "")
580           << kernel_prov.first.debug << " [" << kernel_prov.second << "]\n";
581     }
582   }
583   return oss.str();
584 }
585 
586 void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
587   report_error_callback_ = std::move(callback);
588 }
589 
590 // Inspect the "canonical" information in OperatorEntry.  This only prints out
591 // *non-derived* information including kernels registered to alias dispatch keys;
592 // i.e., what the source of truth says about the operator.  This dumping function
593 // is appropriate for expect tests.
594 // This WON'T report backend fallbacks.
595 std::string OperatorEntry::dumpState() const {
596   std::ostringstream oss;
597   oss << "name: " << name_ << "\n";
598   if (schema_) {
599     oss << "schema: " << schema_->schema << "\n";
600     oss << "debug: " << schema_->debug << "\n";
601     oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis())
602         << (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
603   } else {
604     oss << "schema: (none)\n";
605   }
606 
607   auto print_kernel = [&](const char* k_desc, const AnnotatedKernelContainer& jts, bool is_alias_key=false) {
608     int64_t i = 0;
609     for (const auto& jt : jts) {
610       oss << k_desc
611           << (is_alias_key ? "[alias]" :  "")
612           << (i > 0 ? " (inactive)" : "")
613           << ": "
614           << jt.debug << " :: "
615           << (jt.inferred_function_schema ? toString(*jt.inferred_function_schema) : "(none)")
616           << " [ " << jt.kernel.dumpState() << "]\n";
617       i++;
618     }
619   };
620 
621   // Iterate over DispatchKey, not the flat hash map, so we have a stable order
622   for (uint8_t i = 0; i <= static_cast<uint8_t>(DispatchKey::EndOfAliasKeys); i++) {
623     auto k = static_cast<DispatchKey>(i);
624     auto it = kernels_.find(k);
625     if (it != kernels_.end()) {
626       print_kernel(toString(k), it->second, c10::isAliasDispatchKey(k));
627     }
628   }
629   return oss.str();
630 }
631 
632 }
633 }
634