// xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/jit_decomp_interface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/core/function_schema.h>
#include <c10/macros/Export.h>

// NOTE: [Jit Decomposition Interface]
//
// For some context on why we need this at all, see NOTE: [forward-mode AD
// decompositions mechanism]
//
// Introducing that mechanism from the NOTE is problematic because:
// - it relies on TorchScript, so now VariableTypeX.cpp depends on TorchScript.
// - there exist internal builds like lite_trainer, which depend on VariableType
//   but do not depend on TorchScript.
//
// For internal builds like lite_trainer builds to pass, and for OSS builds that
// do depend on TorchScript to still support the forward AD decomp mechanism, we
// implement a PImpl pattern to avoid a static dependency in favor of a dynamic
// one:
// - during static initialization time, if the library is built with TorchScript,
//   setJitDecompImpl is called in decomposition_registry.cpp, setting a global
//   ptr to the impl
// - when the program is run, if getJitDecompImpl returns a non-null ptr, we can
//   carry on normally; otherwise we gracefully error out
//
// For extra context, see VariableHooksInterface.h, where a similar technique
// is used.
30 namespace torch::autograd::impl {
31 
32 struct TORCH_API JitDecompInterface {
33   virtual ~JitDecompInterface() = default;
34   virtual bool has_jit_decomposition(
35       const c10::FunctionSchema& schema) const = 0;
36   virtual void run_jit_decomposition(
37       const c10::OperatorHandle& op,
38       jit::Stack* stack) const = 0;
39 };
40 
41 TORCH_API void setJitDecompImpl(JitDecompInterface* impl);
42 TORCH_API JitDecompInterface* getJitDecompImpl();
43 
44 struct TORCH_API JitDecompRegisterer {
JitDecompRegistererJitDecompRegisterer45   explicit JitDecompRegisterer(JitDecompInterface* impl) {
46     setJitDecompImpl(impl);
47   }
48 };
49 
50 } // namespace torch::autograd::impl
51