1 #pragma once 2 #include <ATen/functorch/Interpreter.h> 3 4 namespace at::functorch { 5 6 // These are the interpreters for our AD transforms 7 // (grad, vjp and jvp). 8 // See NOTE: [functorch interpreter stack] for more details. 9 10 struct TORCH_API GradInterpreterPtr { GradInterpreterPtrGradInterpreterPtr11 explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); } keyGradInterpreterPtr12 TransformType key() const { return base_->key(); } levelGradInterpreterPtr13 int64_t level() const { return base_->level(); } 14 void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); 15 void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); prevGradModeGradInterpreterPtr16 bool prevGradMode() const { 17 return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_; 18 } 19 Tensor lift(const Tensor& tensor) const; 20 private: 21 const Interpreter* base_; 22 }; 23 24 struct TORCH_API JvpInterpreterPtr { JvpInterpreterPtrJvpInterpreterPtr25 explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); } keyJvpInterpreterPtr26 TransformType key() const { return base_->key(); } levelJvpInterpreterPtr27 int64_t level() const { return base_->level(); } 28 void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); 29 void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); prevFwdGradModeJvpInterpreterPtr30 bool prevFwdGradMode() const { 31 return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_; 32 } 33 Tensor lift(const Tensor& tensor) const; 34 private: 35 const Interpreter* base_; 36 }; 37 38 } // namespace at::functorch 39