1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/core/function_schema.h> 5 #include <c10/macros/Export.h> 6 7 // NOTE: [Jit Decomposition Interface] 8 // 9 // For some context of why we need this at all, see NOTE: [forward-mode AD 10 // decompositions mechanism] 11 // 12 // Introducing that mechanism from the NOTE is problematic because: 13 // - it relies on TorchScript, so now VariableTypeX.cpp depends on TorchScript. 14 // - there exist internal builds like lite_trainer, which depend on VariableType 15 // but do not depend on TorchScript. 16 // 17 // For internal builds like lite_trainer builds to pass, and for OSS builds that 18 // do depend on TorchScript to still support the forward AD decomp mechanism, we 19 // implement a PImpl pattern to avoid a static dependency in favor of a dynamic 20 // one 21 // - during static initialization time, if the library is built with TorchScript 22 // setJitDecompImpl is called in decomposition_registry.cpp setting a global 23 // ptr to the impl 24 // - when the program is run,if getJitDecompImpl returns a non null ptr, we can 25 // carry on normally, otherwise we gracefully error out 26 // 27 // For extra context, see VariableHooksInterface.h, where a similar technique 28 // is used 29 30 namespace torch::autograd::impl { 31 32 struct TORCH_API JitDecompInterface { 33 virtual ~JitDecompInterface() = default; 34 virtual bool has_jit_decomposition( 35 const c10::FunctionSchema& schema) const = 0; 36 virtual void run_jit_decomposition( 37 const c10::OperatorHandle& op, 38 jit::Stack* stack) const = 0; 39 }; 40 41 TORCH_API void setJitDecompImpl(JitDecompInterface* impl); 42 TORCH_API JitDecompInterface* getJitDecompImpl(); 43 44 struct TORCH_API JitDecompRegisterer { JitDecompRegistererJitDecompRegisterer45 explicit JitDecompRegisterer(JitDecompInterface* impl) { 46 setJitDecompImpl(impl); 47 } 48 }; 49 50 } // namespace torch::autograd::impl 51