1 #include <ATen/core/dispatch/DispatchKeyExtractor.h> 2 #include <c10/util/irange.h> 3 4 #include <sstream> 5 6 namespace c10 { 7 setOperatorHasFallthroughForKey(DispatchKey k,bool has_fallthrough)8void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) { 9 // (1) update nonFallthroughKeys_ 10 if (has_fallthrough) { 11 nonFallthroughKeys_ = nonFallthroughKeys_.remove(k); 12 } else { 13 nonFallthroughKeys_ = nonFallthroughKeys_.add(k); 14 } 15 // (2) update nonFallthroughKeysPerBackend_ 16 if (isPerBackendFunctionalityKey(toFunctionalityKey(k))) { 17 // This is a per-backend functionality key. 18 // We need to figure out what the current backend is, 19 // and only update the bitset for that backend. 20 // subtracting 1 because the first backend should have index 0 (CPU), 21 // But the enum starts with BackendComponent::InvalidBit. 22 auto backend_idx = static_cast<uint8_t>(toBackendComponent(k)) - 1; 23 TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast<uint8_t>(backend_idx) < nonFallthroughKeysPerBackend_.size()); 24 if (has_fallthrough) { 25 nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].remove(k); 26 } else { 27 nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].add(k); 28 } 29 30 // Set requiresBitsetPerBackend_ accordingly 31 for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) { 32 if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i+1]) { 33 requiresBitsetPerBackend_ = true; 34 return; 35 } 36 } 37 requiresBitsetPerBackend_ = false; 38 return; 39 } else { 40 // Otherwise, if a fallthrough is set for a functionality that isn't per backend, 41 // Then we update the fallthrough bitset for EVERY backend. 42 // TODO: we could probably optimize this by only lazily updating these values 43 // the first time that we see requiresBitsetPerBackend_ = true 44 // (which should almost never happen) 45 if (has_fallthrough) { 46 for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { 47 nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].remove(k); 48 } 49 } else { 50 for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { 51 nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].add(k); 52 } 53 } 54 } 55 } 56 dumpState() const57std::string DispatchKeyExtractor::dumpState() const { 58 std::ostringstream oss; 59 for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) { 60 if (dispatch_arg_indices_reverse_.get(i)) { 61 oss << "1"; 62 } else { 63 oss << "0"; 64 } 65 } 66 oss << " " << nonFallthroughKeys_ << "\n"; 67 return oss.str(); 68 } 69 checkInvariants(const FunctionSchema & schema) const70void DispatchKeyExtractor::checkInvariants(const FunctionSchema& schema) const { 71 TORCH_INTERNAL_ASSERT(makeBitsetForDispatchArgs(schema) == dispatch_arg_indices_reverse_); 72 } 73 74 } // namespace c10 75