xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/FunctionalizeInterpreter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/functorch/Interpreter.h>
3 
4 namespace at::functorch {
5 
// This is the interpreter that handles the functionalize() transform.
// See NOTE: [functorch interpreter stack] for more details on how
// interpreters are stacked and dispatched.
8 
9 struct FunctionalizeInterpreterPtr {
FunctionalizeInterpreterPtrFunctionalizeInterpreterPtr10   explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
keyFunctionalizeInterpreterPtr11   TransformType key() const { return base_->key(); }
levelFunctionalizeInterpreterPtr12   int64_t level() const { return base_->level(); }
13   void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
14   void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
functionalizeAddBackViewsFunctionalizeInterpreterPtr15   bool functionalizeAddBackViews() const {
16     return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
17   }
18  private:
19   const Interpreter* base_;
20 };
21 
22 } // namespace at::functorch
23