xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/ADInterpreters.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/functorch/Interpreter.h>
3 
4 namespace at::functorch {
5 
6 // These are the interpreters for our AD transforms
7 // (grad, vjp and jvp).
8 // See NOTE: [functorch interpreter stack] for more details.
9 
10 struct TORCH_API GradInterpreterPtr {
GradInterpreterPtrGradInterpreterPtr11   explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
keyGradInterpreterPtr12   TransformType key() const { return base_->key(); }
levelGradInterpreterPtr13   int64_t level() const { return base_->level(); }
14   void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
15   void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
prevGradModeGradInterpreterPtr16   bool prevGradMode() const {
17     return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
18   }
19   Tensor lift(const Tensor& tensor) const;
20  private:
21   const Interpreter* base_;
22 };
23 
24 struct TORCH_API JvpInterpreterPtr {
JvpInterpreterPtrJvpInterpreterPtr25   explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
keyJvpInterpreterPtr26   TransformType key() const { return base_->key(); }
levelJvpInterpreterPtr27   int64_t level() const { return base_->level(); }
28   void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
29   void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
prevFwdGradModeJvpInterpreterPtr30   bool prevFwdGradMode() const {
31     return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
32   }
33   Tensor lift(const Tensor& tensor) const;
34  private:
35   const Interpreter* base_;
36 };
37 
38 } // namespace at::functorch
39