1 #pragma once
2
3 #include <cstdint>
4 #include <ATen/core/function_schema.h>
5 #include <ATen/core/jit_type.h>
6 #include <c10/util/Bitset.h>
7 #include <c10/core/DispatchKeySet.h>
8 #include <c10/util/irange.h>
9 #include <ATen/core/Variadic.h>
10 #include <ATen/core/stack.h>
11
12 namespace c10 {
13
14 namespace impl {
15
16 // Take a DispatchKeySet for a Tensor and determine what the actual dispatch
17 // DispatchKey should be, taking into account TLS, and skipping backends which
18 // fall through.
19 //
20 // Unlike Tensor::key_set(), the value of this on a tensor can change depending
21 // on TLS.
22 //
23 // NB: If there is no valid dispatch key, this will return Undefined
computeDispatchKeySet(DispatchKeySet ks,DispatchKeySet key_mask)24 inline DispatchKeySet computeDispatchKeySet(
25 DispatchKeySet ks,
26 // The key mask lets us eliminate (by zero entries) keys which should not
27 // be considered for dispatch. There are two cases when we use this:
28 //
29 // - If an operator's dispatch table contains a fallthrough entry, we
30 // should bypass it entirely when finding the key
31 // - If a user invokes with redispatch, the mask lets us
32 // zero out the key the user asked us to stop.
33 //
34 // These excluded backends are NOT tracked in the TLS, but must be applied
35 // AFTER TLS (since the backend may have been introduced for consideration
36 // by the included TLS), which is why you have to pass them in to this
37 // function (as opposed to just applying it to the input 'ks').
38 DispatchKeySet key_mask
39 ) {
40 c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set();
41 // TODO: It's a bit irritating that we have to do logical ORs here, it would
42 // be nice to only do one. Can always_included be folded into the TLS? Well,
43 // it's a bit troublesome, because fastpath TLS access requires the type of
44 // the TLS in question to be zero-initialized, so you don't actually win
45 // anything in that case.
46 return (((ks | local.included_) - local.excluded_) & key_mask);
47 }
48
49 }
50
51 namespace detail {
52 // A small gadget to extract the DispatchKeySet from types which are known
53 // to have it. Used to extract dispatch keys from unboxed calls.
54 struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
55 DispatchKeySet ts;
operatorMultiDispatchKeySet56 void operator()(const at::Tensor& x) {
57 ts = ts | x.key_set();
58 }
operatorMultiDispatchKeySet59 void operator()(const std::optional<at::Tensor>& x) {
60 if (x.has_value()) {
61 ts = ts | x->key_set();
62 }
63 }
operatorMultiDispatchKeySet64 void operator()(at::ArrayRef<at::Tensor> xs) {
65 for (const auto& x : xs) {
66 ts = ts | x.key_set();
67 }
68 }
69 // Tensor?[] translates to this case.
operatorMultiDispatchKeySet70 void operator()(const c10::List<std::optional<at::Tensor>>& xs) {
71 for (std::optional<at::Tensor> x : xs) {
72 if (x.has_value()) {
73 ts = ts | x.value().key_set();
74 }
75 }
76 }
77 // Structured Tensor[] translates to this case
operatorMultiDispatchKeySet78 void operator()(const at::ITensorListRef& xs) {
79 for (const auto& x : xs) {
80 ts = ts | x.key_set();
81 }
82 }
operatorMultiDispatchKeySet83 [[noreturn]] void operator()(at::ArrayRef<std::optional<at::Tensor>>) {
84 // Just checking that the handling of Tensor?[] didn't change.
85 TORCH_INTERNAL_ASSERT(false);
86 }
operatorMultiDispatchKeySet87 void operator()(const at::Generator& gen) {
88 if (gen.defined()) {
89 ts = ts | gen.key_set();
90 }
91 }
operatorMultiDispatchKeySet92 void operator()(const std::optional<at::Generator>& gen) {
93 if (gen.has_value() && gen->defined()) {
94 ts = ts | gen->key_set();
95 }
96 }
97 template <typename T>
operatorMultiDispatchKeySet98 void operator()(const T&) {
99 // do nothing
100 }
101 };
102
103 // NB: take by const reference (Don't do universal forwarding here! You
104 // don't want to move into this function!)
105 template <typename... Args>
multi_dispatch_key_set(const Args &...args)106 DispatchKeySet multi_dispatch_key_set(const Args&... args) {
107 return MultiDispatchKeySet().apply(args...).ts;
108 }
109 }
110
111 /**
112 * An instance of DispatchKeyExtractor knows how to get a dispatch key given
113 * a list of arguments for an operator call.
114 *
115 * The instance is specific for a certain operator as:
116 * - In boxed dispatch, different operators have different ways to extract
117 * the dispatch key (e.g. different numbers of arguments), and we precompute
118 * the stack locations we should look at; and
119 * - In all dispatch, some backends should be excluded from dispatch because
120 * they have been registered as fallthrough. The set of excluded backends
121 * varies from operator, as some operators may have overridden the
122 * fallthrough with custom behavior.
123 *
124 * Note - this should maintain identical impl to the py dispatcher key extraction logic
125 * at pytorch/torch/dispatcher.py
126 */
127 struct TORCH_API DispatchKeyExtractor final {
128 public:
makefinal129 static DispatchKeyExtractor make(const FunctionSchema& schema) {
130 return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
131 }
132
makeUninitializedfinal133 static DispatchKeyExtractor makeUninitialized() {
134 return DispatchKeyExtractor(c10::utils::bitset());
135 }
136
registerSchemafinal137 void registerSchema(const FunctionSchema& schema) {
138 TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
139 dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
140 }
deregisterSchemafinal141 void deregisterSchema() {
142 dispatch_arg_indices_reverse_ = c10::utils::bitset();
143 }
144
getDispatchKeySetBoxedfinal145 DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
146 DispatchKeySet ks;
147 dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
148 const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
149 if (C10_LIKELY(ivalue.isTensor())) {
150 // NB: Take care not to introduce a refcount bump (there's
151 // no safe toTensorRef method, alas)
152 ks = ks | ivalue.unsafeToTensorImpl()->key_set();
153 } else if (C10_UNLIKELY(ivalue.isTensorList())) {
154 for (const at::Tensor& tensor : ivalue.toTensorList()) {
155 ks = ks | tensor.key_set();
156 }
157 }
158 // Tensor?[] translates to a c10::List<IValue> so we need to peek inside
159 else if (C10_UNLIKELY(ivalue.isList())) {
160 for (const auto& elt : ivalue.toListRef()) {
161 if (elt.isTensor()) {
162 ks = ks | elt.toTensor().key_set();
163 }
164 }
165 }
166 });
167 // Keys that are fallthrough should be skipped
168 if (requiresBitsetPerBackend_) {
169 auto backend_idx = ks.getBackendIndex();
170 return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
171 } else {
172 return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
173 }
174 }
175
176 template<class... Args>
getDispatchKeySetUnboxedfinal177 DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
178 auto ks = detail::multi_dispatch_key_set(args...);
179 // Keys that are fallthrough should be skipped
180 if (requiresBitsetPerBackend_) {
181 auto backend_idx = ks.getBackendIndex();
182 return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
183 } else {
184 return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
185 }
186 }
187
188 void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
189
190 std::string dumpState() const;
191 void checkInvariants(const FunctionSchema& schema) const;
192
193 private:
makeBitsetForDispatchArgsfinal194 static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
195 TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
196 "The function schema has ", schema.arguments().size(),
197 " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
198 c10::utils::bitset dispatch_arg_indices_reverse;
199 for (const auto index : c10::irange(schema.arguments().size())) {
200 if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
201 schema.arguments()[index].type()->isSubtypeOf(
202 *ListType::ofTensors()) ||
203 schema.arguments()[index].type()->isSubtypeOf(
204 *ListType::ofOptionalTensors()) ||
205 schema.arguments()[index].type()->isSubtypeOf(
206 *OptionalType::ofTensor())) {
207 dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
208 }
209 }
210 return dispatch_arg_indices_reverse;
211 }
212
DispatchKeyExtractorfinal213 explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
214 : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
215 , nonFallthroughKeys_(DispatchKeySet::FULL)
216 , requiresBitsetPerBackend_(false) {
217 for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
218 nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
219 }
220 }
221
222 // this is a bitset that has ones for each argument index which has to be
223 // considered for dispatch. This avoids having to iterate over the stack
224 // to find all the tensors. The bits are stored in reverse order, i.e.
225 // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
226 // the top of the stack (i.e. the i-th last argument of the function)
227 // is relevant for dispatch.
228 // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just means you must do the
229 // fallthrough
230 c10::utils::bitset dispatch_arg_indices_reverse_;
231
232 // Set of functionality keys for which the operator does NOT have fallthrough kernel.
233 DispatchKeySet nonFallthroughKeys_;
234 // Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND.
235 // This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
236 std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
237 // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
238 // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_
239 bool requiresBitsetPerBackend_;
240 };
241
242 }
243