1 #pragma once 2 3 #include <ATen/core/function_schema.h> 4 #include <c10/util/Metaprogramming.h> 5 #include <c10/util/flat_hash_map.h> 6 #include <c10/core/DispatchKey.h> 7 #include <c10/core/PyHandleCache.h> 8 #include <c10/core/SafePyObject.h> 9 #include <ATen/core/ivalue.h> 10 #include <ATen/core/boxing/KernelFunction.h> 11 #include <ATen/core/dispatch/DispatchKeyExtractor.h> 12 13 #include <ATen/core/dispatch/OperatorOptions.h> 14 #include <ATen/core/dispatch/CppSignature.h> 15 #include <ATen/core/dispatch/RegistrationHandleRAII.h> 16 #include <ATen/core/enum_tag.h> 17 18 #include <optional> 19 #include <array> 20 #include <list> 21 22 #ifdef C10_MOBILE 23 #define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY 24 #endif 25 26 namespace c10 { 27 28 class Dispatcher; 29 30 namespace impl { 31 32 // This data structure represents a kernel that was registered to us from a 33 // user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata 34 // about the kernel that isn't necessary for actual dispatching (this is why 35 // we don't put AnnotatedKernel in the actual DispatchTable), but is useful for 36 // giving good error messages. 37 struct AnnotatedKernel final { AnnotatedKernelfinal38 AnnotatedKernel(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d) 39 : kernel(std::move(k)) 40 , inferred_function_schema(std::move(s)) 41 , debug(std::move(d)) 42 {} 43 AnnotatedKernel() = default; 44 KernelFunction kernel; 45 std::unique_ptr<FunctionSchema> inferred_function_schema; 46 // A little debug string to help us identify the kernel in question. 47 // Most importantly it records the TORCH_LIBRARY block that did the 48 // registration. 49 std::string debug; 50 }; 51 52 // This data structure represents operator schema, with metadata specifying 53 // where the registration of this schema occurred 54 struct AnnotatedSchema final { AnnotatedSchemafinal55 AnnotatedSchema(FunctionSchema s, std::string d) 56 : schema(std::move(s)) 57 , debug(std::move(d)) 58 {} 59 FunctionSchema schema; 60 std::string debug; 61 }; 62 63 // Internal data structure that records information about a specific operator. 64 // It's not part of the public API; typically, users will interact with 65 // OperatorHandle instead. 66 // 67 // Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher 68 // lock (this is important because some methods in OperatorEntry access 69 // dispatcher state) 70 class TORCH_API OperatorEntry final { 71 public: 72 explicit OperatorEntry(OperatorName&& operator_name); 73 74 OperatorEntry(const OperatorEntry&) = delete; 75 OperatorEntry(OperatorEntry&&) noexcept = delete; 76 OperatorEntry& operator=(const OperatorEntry&) = delete; 77 OperatorEntry& operator=(OperatorEntry&&) noexcept = delete; 78 schema()79 const FunctionSchema& schema() const { 80 TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema for ", name_, " which doesn't have a schema registered yet"); 81 return schema_->schema; 82 } debug()83 const std::string& debug() const { 84 TORCH_INTERNAL_ASSERT(schema_.has_value()); 85 return schema_->debug; 86 } hasSchema()87 bool hasSchema() const { 88 return schema_.has_value(); 89 } 90 isObserved()91 bool isObserved() const { 92 return is_observed_; 93 } 94 95 // We may allocate an OperatorEntry for an operator even when we don't 96 // have a schema. When we receive the schema registration, we post 97 // facto register a schema. 98 // 99 // NB: registerSchema/deregisterSchema are not idempotent; if you 100 // attempt to register a schema when one is already present or vice 101 // versa that is an error. (Refcounting for the registrations is 102 // handled in the OperatorHandle in Dispatcher) 103 void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {}); 104 void deregisterSchema(); 105 operator_name()106 const OperatorName& operator_name() const { 107 return name_; 108 } 109 110 #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY 111 using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>; 112 #else 113 using AnnotatedKernelContainer = std::list<AnnotatedKernel>; 114 #endif 115 using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator; 116 117 // Why are kernels and fallback asymmetric? It has to do with ownership. 118 // Kernels and the computed dispatch tables for them are canonically 119 // owned by OperatorEntry, but backend fallbacks are specified once 120 // and apply for all operators, so they should be owned by Dispatcher. 121 // However, the registration of a backend fallback affects the 122 // state of the computed dispatch table, so when a backend fallback 123 // is updated, we need to update the operator tables too. Thus, 124 // registerKernel is the mechanism by which we give kernels to 125 // operator entry to own (and update dispatch table), but we only 126 // need a non-owning mechanism to update fallback. 127 128 // Precondition: Dispatcher::mutex_ is held 129 // Postcondition: caller is responsible for disposing of the kernel 130 AnnotatedKernelContainerIterator registerKernel( 131 const Dispatcher& dispatcher, 132 std::optional<DispatchKey> dispatch_key, 133 KernelFunction kernel, 134 std::optional<CppSignature> cpp_signature, 135 std::unique_ptr<FunctionSchema> inferred_function_schema, 136 std::string debug 137 ); 138 139 // Precondition: Dispatcher::mutex_ is held 140 void deregisterKernel_( 141 const Dispatcher& dispatcher, 142 std::optional<DispatchKey> dispatch_key, 143 AnnotatedKernelContainerIterator kernel 144 ); 145 146 // Precondition: Dispatcher::mutex_ is held 147 void updateFallback( 148 const Dispatcher& dispatcher, 149 DispatchKey dispatch_key 150 ); 151 152 // Precondition: Dispatcher::mutex_ is held updateSchemaAliasAnalysis(AliasAnalysisKind a)153 void updateSchemaAliasAnalysis(AliasAnalysisKind a) { 154 TORCH_INTERNAL_ASSERT(schema_.has_value()); 155 schema_->schema.setAliasAnalysis(a); 156 } 157 158 std::string dumpComputedTable() const; 159 std::string dumpState() const; 160 void checkInvariants() const; 161 dispatchKeyExtractor()162 const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; } 163 164 // Asserts that the given FuncType is correct for calling this operator in an unboxed way. 165 template<class FuncType> assertSignatureIsCorrect()166 inline void assertSignatureIsCorrect() { 167 assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value); 168 } 169 170 void assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const; 171 172 [[noreturn]] void reportError(DispatchKey dispatchKey) const; 173 lookup(DispatchKeySet ks)174 const KernelFunction& lookup(DispatchKeySet ks) const { 175 const auto idx = ks.getDispatchTableIndexForDispatchKeySet(); 176 if (C10_UNLIKELY(idx == -1)) { 177 reportError(ks.highestPriorityTypeId()); 178 } 179 const auto& kernel = dispatchTable_[idx]; 180 // A valid kernel *always* has a boxed kernel and *may* have an 181 // unboxed kernel. However, we typically do unboxed calls in at:: 182 // APIs, where the kernel 1) will very likely be valid and 2) 183 // should have an unboxed kernel. Checking the unboxed kernel 184 // first will allow us to avoid touching the boxed kernel at all 185 // in the common case. 186 if (C10_UNLIKELY(!kernel.isValidUnboxed())) { 187 if (!kernel.isValid()) { 188 reportError(ks.highestPriorityTypeId()); 189 } 190 } 191 return kernel; 192 } 193 194 std::string listAllDispatchKeys() const; 195 196 // Returns true if kernel_ has entry for any key in ks. 197 // 198 // Invariant: There are no alias keys in the passed-in dispatch key set. 199 // Note [No Alias Keys in DispatchKeySet] 200 // Alias keys should be checked using `hasKernelForDispatchKey` 201 // Alias keys shouldn't go inside of a DispatchKeySet, since they can technically 202 // have a value > 63 (causing overflow). 203 bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const; 204 // Returns true if kernel_ has entry for a particular key. 205 bool hasKernelForDispatchKey(DispatchKey k) const; 206 // Retrieves the kernel entry at a particular key. Symmetric with 207 // hasKernelForDispatchKey. To get the AnnotatedKernel, see 208 // getKernelForDispatchKey (private) 209 const KernelFunction& kernelForDispatchKey(DispatchKey k) const; 210 // Returns true if the "computed table" has an entry for a particular key. 211 bool hasComputedKernelForDispatchKey(DispatchKey k) const; 212 // Returns all the operator tags added at the time of registration 213 const std::vector<at::Tag>& getTags() const; 214 void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback); 215 216 template <typename F> getPythonOp(PyInterpreter * self_interpreter,F slow_accessor)217 PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const { 218 return py_cache_.ptr_or(self_interpreter, slow_accessor); 219 } 220 221 private: 222 223 OperatorName name_; 224 std::optional<AnnotatedSchema> schema_; 225 #ifndef C10_MOBILE 226 std::vector<at::Tag> tags_; 227 #endif 228 std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_; 229 DispatchKeyExtractor dispatchKeyExtractor_; 230 // Pointer to the torch.ops.ns.op.overload object for speed 231 c10::PyHandleCache py_cache_; 232 233 // kernels_ stores all registered kernels for the corresponding dispatch key 234 // and catchAllKernels_ stores the catch-all kernels. 235 // If an operator library gets loaded that overwrites an already existing kernel, 236 // both kernels will be in that list but only the newer one will be in 237 // dispatchTable. If any of the kernels go away (say the library gets 238 // unloaded), we remove the kernel from this list and update the 239 // dispatchTable if necessary. 240 // Kernels in the list are ordered by registration time descendingly, 241 // newer registrations are before older registrations. 242 // We do not combine dispatchTable and kernels into one hash map because 243 // kernels is a larger data structure and accessed quite infrequently 244 // while dispatchTable is accessed often and should be kept small to fit 245 // into CPU caches. 246 // Invariants: 247 // - dispatchTable[dispatch_key] == kernels_[dispatch_key].front() 248 // - dispatchTable[dispatch_key] does not exist if and only if 249 // kernels_[dispatch_key] does not exist 250 // - If kernels_[dispatch_key] exists, then it has elements. 251 // It is never an empty list. 252 // 253 // Why do we do that? 254 // ----- 255 // We mostly do this to enable Jupyter notebooks where a cell registering 256 // a kernel could be executed multiple times and the later execution 257 // should overwrite the earlier one. Note that this still fails when the 258 // function schema changed between the executions, but it works as long 259 // as the function schema didn't change. A better solution would be to 260 // unload the old extension library from the Jupyter cell when the cell is 261 // re-executed and then only allow one kernel here, i.e. error if a kernel 262 // is already registered, but that's a lot of effort to implement and 263 // currently not high-pri. 264 ska::flat_hash_map<DispatchKey, 265 #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY 266 // On mobile, we needn't worry about Jupyter notebooks. 267 std::array<AnnotatedKernel, 1> 268 #else 269 std::list<AnnotatedKernel> 270 #endif 271 > kernels_; 272 273 const AnnotatedKernel& missingKernel() const; 274 const AnnotatedKernel& ambiguousAutogradOtherKernel() const; 275 276 // cpp_signature_ stores function signature if any of 277 // the kernels was created in a way that allowed us to know the function 278 // signature (i.e. by supplying an unboxed C++ kernel function). 279 // If this is set, it will be used to check that future kernel 280 // registrations match and it will be used in unboxed function calls 281 // to verify their arguments against the known function signature. 282 struct CppSignatureWithDebug { 283 CppSignature signature; 284 std::string debug; 285 std::optional<DispatchKey> dispatch_key; 286 }; 287 std::optional<CppSignatureWithDebug> cpp_signature_; 288 std::optional<CppSignatureWithDebug> sym_cpp_signature_; 289 290 // A Python custom error handler for OperatorEntry::reportError 291 std::unique_ptr<c10::SafePyObject> report_error_callback_; 292 293 // Whether this operator needs to be observed with RecordFunction 294 const bool is_observed_; 295 296 [[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const; 297 const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const; 298 std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug( 299 const c10::Dispatcher& dispatcher, DispatchKey dispatch_key 300 ) const; 301 // This function re-establishes the invariant that dispatchTable 302 // contains the front element from the kernels list for a given runtime dispatch key. 303 void updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key); 304 // Like above, but also handles alias dispatch keys. 305 void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key); 306 // Like above, but for ALL entries in the dispatch table. 307 void updateDispatchTableFull_(const c10::Dispatcher& dispatcher); 308 // Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front(). 309 const AnnotatedKernel* getKernelForDispatchKey(DispatchKey dispatch_key) const; 310 }; 311 312 } // namespace impl 313 } // namespace c10 314