xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/dispatch/DispatchKeyExtractor.h>
2 #include <c10/util/irange.h>
3 
4 #include <sstream>
5 
6 namespace c10 {
7 
setOperatorHasFallthroughForKey(DispatchKey k,bool has_fallthrough)8 void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) {
9   // (1) update nonFallthroughKeys_
10   if (has_fallthrough) {
11     nonFallthroughKeys_ = nonFallthroughKeys_.remove(k);
12   } else {
13     nonFallthroughKeys_ = nonFallthroughKeys_.add(k);
14   }
15   // (2) update nonFallthroughKeysPerBackend_
16   if (isPerBackendFunctionalityKey(toFunctionalityKey(k))) {
17     // This is a per-backend functionality key.
18     // We need to figure out what the current backend is,
19     // and only update the bitset for that backend.
20     // subtracting 1 because the first backend should have index 0 (CPU),
21     // But the enum starts with BackendComponent::InvalidBit.
22     auto backend_idx = static_cast<uint8_t>(toBackendComponent(k)) - 1;
23     TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast<uint8_t>(backend_idx) < nonFallthroughKeysPerBackend_.size());
24     if (has_fallthrough) {
25       nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].remove(k);
26     } else {
27       nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].add(k);
28     }
29 
30     // Set requiresBitsetPerBackend_ accordingly
31     for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) {
32       if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i+1]) {
33         requiresBitsetPerBackend_ = true;
34         return;
35       }
36     }
37     requiresBitsetPerBackend_ = false;
38     return;
39   } else {
40     // Otherwise, if a fallthrough is set for a functionality that isn't per backend,
41     // Then we update the fallthrough bitset for EVERY backend.
42     // TODO: we could probably optimize this by only lazily updating these values
43     // the first time that we see requiresBitsetPerBackend_ = true
44     // (which should almost never happen)
45     if (has_fallthrough) {
46       for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
47         nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].remove(k);
48       }
49     } else {
50       for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
51         nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].add(k);
52       }
53     }
54   }
55 }
56 
dumpState() const57 std::string DispatchKeyExtractor::dumpState() const {
58   std::ostringstream oss;
59   for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) {
60     if (dispatch_arg_indices_reverse_.get(i)) {
61       oss << "1";
62     } else {
63       oss << "0";
64     }
65   }
66   oss << " " << nonFallthroughKeys_ << "\n";
67   return oss.str();
68 }
69 
checkInvariants(const FunctionSchema & schema) const70 void DispatchKeyExtractor::checkInvariants(const FunctionSchema& schema) const {
71   TORCH_INTERNAL_ASSERT(makeBitsetForDispatchArgs(schema) == dispatch_arg_indices_reverse_);
72 }
73 
74 } // namespace c10
75