#pragma once

#include <cstdint>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Bitset.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
#include <ATen/core/Variadic.h>
#include <ATen/core/stack.h>

namespace c10 {

namespace impl {

// Take a DispatchKeySet for a Tensor and determine what the actual dispatch
// DispatchKey should be, taking into account TLS, and skipping backends which
// fall through.
//
// Unlike Tensor::key_set(), the value of this on a tensor can change depending
// on TLS.
//
// NB: If there is no valid dispatch key, this will return Undefined
inline DispatchKeySet computeDispatchKeySet(
    DispatchKeySet ks,
    // The key mask lets us eliminate (by zeroing entries) keys which should not
    // be considered for dispatch.  There are two cases when we use this:
    //
    // - If an operator's dispatch table contains a fallthrough entry, we
    //   should bypass it entirely when finding the key.
    // - If a user invokes with redispatch, the mask lets us
    //   zero out the key at which the user asked us to stop.
    //
    // These excluded backends are NOT tracked in the TLS, but must be applied
    // AFTER TLS (since the backend may have been introduced for consideration
    // by the included TLS), which is why you have to pass them in to this
    // function (as opposed to just applying it to the input 'ks').
    DispatchKeySet key_mask
) {
  c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set();
  // TODO: It's a bit irritating that we have to do logical ORs here, it would
  // be nice to only do one.  Can always_included be folded into the TLS?  Well,
  // it's a bit troublesome, because fastpath TLS access requires the type of
  // the TLS in question to be zero-initialized, so you don't actually win
  // anything in that case.
  return (((ks | local.included_) - local.excluded_) & key_mask);
}
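
// Illustrative sketch (not part of the API): suppose a tensor's key set
// resolves to {AutogradCPU, CPU} and the local TLS excludes the autograd keys
// (as it does inside an at::AutoDispatchBelowAutograd guard). Then
//
//   computeDispatchKeySet(tensor.key_set(), DispatchKeySet::FULL)
//
// drops AutogradCPU, the highest priority remaining key is CPU, and the call
// dispatches straight to the CPU kernel.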

} // namespace impl

namespace detail {
  // A small gadget to extract the DispatchKeySet from types which are known
  // to have it.  Used to extract dispatch keys from unboxed calls.
  struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
    DispatchKeySet ts;
    void operator()(const at::Tensor& x) {
      ts = ts | x.key_set();
    }
    void operator()(const std::optional<at::Tensor>& x) {
      if (x.has_value()) {
        ts = ts | x->key_set();
      }
    }
    void operator()(at::ArrayRef<at::Tensor> xs) {
      for (const auto& x : xs) {
        ts = ts | x.key_set();
      }
    }
    // Tensor?[] translates to this case.
    void operator()(const c10::List<std::optional<at::Tensor>>& xs) {
      for (std::optional<at::Tensor> x : xs) {
        if (x.has_value()) {
          ts = ts | x.value().key_set();
        }
      }
    }
    // Structured Tensor[] translates to this case
    void operator()(const at::ITensorListRef& xs) {
      for (const auto& x : xs) {
        ts = ts | x.key_set();
      }
    }
    [[noreturn]] void operator()(at::ArrayRef<std::optional<at::Tensor>>) {
      // Just checking that the handling of Tensor?[] didn't change.
      TORCH_INTERNAL_ASSERT(false);
    }
    void operator()(const at::Generator& gen) {
      if (gen.defined()) {
        ts = ts | gen.key_set();
      }
    }
    void operator()(const std::optional<at::Generator>& gen) {
      if (gen.has_value() && gen->defined()) {
        ts = ts | gen->key_set();
      }
    }
    template <typename T>
    void operator()(const T&) {
      // do nothing
    }
  };

  // NB: take by const reference (Don't do universal forwarding here! You
  // don't want to move into this function!)
  template <typename... Args>
  DispatchKeySet multi_dispatch_key_set(const Args&... args) {
    return MultiDispatchKeySet().apply(args...).ts;
  }
} // namespace detail
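
// Illustrative sketch (hypothetical call, not part of this header): given
// unboxed arguments such as a Tensor, an optional Tensor, and a double,
//
//   at::Tensor self = ..., other = ...;
//   DispatchKeySet ks = detail::multi_dispatch_key_set(
//       self, std::optional<at::Tensor>(other), 2.0);
//
// ks is the union of self.key_set() and other.key_set(); the double hits the
// catch-all overload and contributes nothing.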

/**
 * An instance of DispatchKeyExtractor knows how to get a dispatch key given
 * a list of arguments for an operator call.
 *
 * The instance is specific to a certain operator as:
 *  - In boxed dispatch, different operators have different ways to extract
 *    the dispatch key (e.g. different numbers of arguments), and we precompute
 *    the stack locations we should look at; and
 *  - In all dispatch (boxed and unboxed), some backends should be excluded
 *    from dispatch because they have been registered as fallthrough.  The set
 *    of excluded backends varies from operator to operator, as some operators
 *    may have overridden the fallthrough with custom behavior.
 *
 *   Note - this should maintain an identical impl to the py dispatcher key
 *   extraction logic at pytorch/torch/dispatcher.py
 */
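// A rough usage sketch (the dispatcher is the real caller; names here are
// illustrative): the extractor is built once from the operator's schema and
// then queried on every call:
//
//   auto extractor = DispatchKeyExtractor::make(op.schema());
//   // unboxed path:
//   DispatchKeySet ks = extractor.getDispatchKeySetUnboxed(self, other);
//   // boxed path:
//   DispatchKeySet ks2 = extractor.getDispatchKeySetBoxed(&stack);
//
// The highest priority key in the returned set selects the kernel.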
struct TORCH_API DispatchKeyExtractor final {
public:
  static DispatchKeyExtractor make(const FunctionSchema& schema) {
    return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
  }

  static DispatchKeyExtractor makeUninitialized() {
    return DispatchKeyExtractor(c10::utils::bitset());
  }

  void registerSchema(const FunctionSchema& schema) {
    TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
    dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
  }
  void deregisterSchema() {
    dispatch_arg_indices_reverse_ = c10::utils::bitset();
  }

  DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
    DispatchKeySet ks;
    dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
      const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
      if (C10_LIKELY(ivalue.isTensor())) {
        // NB: Take care not to introduce a refcount bump (there's
        // no safe toTensorRef method, alas)
        ks = ks | ivalue.unsafeToTensorImpl()->key_set();
      } else if (C10_UNLIKELY(ivalue.isTensorList())) {
        for (const at::Tensor& tensor : ivalue.toTensorList()) {
          ks = ks | tensor.key_set();
        }
      }
      // Tensor?[] translates to a c10::List<IValue> so we need to peek inside
      else if (C10_UNLIKELY(ivalue.isList())) {
        for (const auto& elt : ivalue.toListRef()) {
          if (elt.isTensor()) {
            ks = ks | elt.toTensor().key_set();
          }
        }
      }
    });
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }
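
  // Illustrative sketch: for a schema like add.Tensor(Tensor self, Tensor other,
  // *, Scalar alpha=1), the stack is [self, other, alpha] with alpha on top.
  // The bitset has reverse indices 1 (other) and 2 (self) set, so the lambda
  // above peeks at stack depths 2 and 3 and unions the two tensors' key sets;
  // alpha (reverse index 0) is never touched.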

  template<class... Args>
  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
    auto ks = detail::multi_dispatch_key_set(args...);
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }

  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);

  std::string dumpState() const;
  void checkInvariants(const FunctionSchema& schema) const;

private:
  static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
    TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
        "The function schema has ", schema.arguments().size(),
        " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
    c10::utils::bitset dispatch_arg_indices_reverse;
    for (const auto index : c10::irange(schema.arguments().size())) {
      if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *ListType::ofTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *ListType::ofOptionalTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *OptionalType::ofTensor())) {
        dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
      }
    }
    return dispatch_arg_indices_reverse;
  }
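
  // Illustrative sketch: for a hypothetical schema
  //   foo(Tensor self, int n, Tensor[] others)
  // arguments 0 and 2 are dispatch-relevant, so the returned bitset has bits
  // set at reverse indices 2 (self) and 0 (others); bit 1 (n) stays clear.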

  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
  : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
  , nonFallthroughKeys_(DispatchKeySet::FULL)
  , requiresBitsetPerBackend_(false) {
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
    }
  }

  // this is a bitset that has ones for each argument index which has to be
  // considered for dispatch. This avoids having to iterate over the stack
  // to find all the tensors. The bits are stored in reverse order, i.e.
  // if dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
  // the top of the stack (i.e. the i-th last argument of the function)
  // is relevant for dispatch.
  // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
  // means you must do the fallthrough
  c10::utils::bitset dispatch_arg_indices_reverse_;

  // Set of functionality keys for which the operator does NOT have a fallthrough kernel.
  DispatchKeySet nonFallthroughKeys_;
  // Set of functionality keys for which the operator does NOT have a fallthrough kernel, defined PER BACKEND.
  // This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
  // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_.
  bool requiresBitsetPerBackend_;
};

} // namespace c10