// torch/csrc/jit/runtime/simple_graph_executor_impl.cpp
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>

#include <torch/csrc/jit/runtime/simple_graph_executor_impl.h>
#include <mutex>
#include <optional>

namespace torch::jit {

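// The constructor only stores the graph and function name; compiling the
// graph into an ExecutionPlan is deferred until the first call to
// getPlanFor().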
SimpleGraphExecutorImpl::SimpleGraphExecutorImpl(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : GraphExecutorImplBase(graph, std::move(function_name)) {}

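// Lazily builds the ExecutionPlan on first use: the graph is copied, run
// through the no-optimization pass pipeline, and the resulting plan is cached
// under compile_mutex. Later calls return the cached plan; the stack and
// remaining_bailout_depth arguments are not consulted by this simple executor.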
const ExecutionPlan& SimpleGraphExecutorImpl::getPlanFor(
    Stack& stack,
    std::optional<size_t> remaining_bailout_depth) {
  std::lock_guard<std::mutex> lock(compile_mutex);

  // IMPORTANT: This is a hot path of calling a torchscript function. Try not to
  // add any code above this.
  if (execution_plan_) {
    return *execution_plan_;
  }
  auto copy = graph->copy();
  runNooptPassPipeline(copy);
  execution_plan_ = ExecutionPlan(copy, function_name_);

  return *execution_plan_;
}

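// Returns the graph and the single cached ExecutionPlan, keyed by an empty
// ArgumentSpec, for debugging/inspection. Asserts that a plan has already
// been compiled.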
GraphExecutorState SimpleGraphExecutorImpl::getDebugState() {
  GraphExecutorState state;
  TORCH_INTERNAL_ASSERT(execution_plan_);
  state.graph = execution_plan_->graph.get();
  auto opt_plan = *execution_plan_;
  state.execution_plans.emplace(ArgumentSpec{0, 0}, opt_plan);
  return state;
}

} // namespace torch::jit