1 #include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h> 2 3 #include <torch/csrc/jit/runtime/simple_graph_executor_impl.h> 4 #include <mutex> 5 #include <optional> 6 7 namespace torch::jit { 8 SimpleGraphExecutorImpl(const std::shared_ptr<Graph> & graph,std::string function_name)9SimpleGraphExecutorImpl::SimpleGraphExecutorImpl( 10 const std::shared_ptr<Graph>& graph, 11 std::string function_name) 12 : GraphExecutorImplBase(graph, std::move(function_name)) {} 13 getPlanFor(Stack & stack,std::optional<size_t> remaining_bailout_depth)14const ExecutionPlan& SimpleGraphExecutorImpl::getPlanFor( 15 Stack& stack, 16 std::optional<size_t> remaining_bailout_depth) { 17 std::lock_guard<std::mutex> lock(compile_mutex); 18 19 // IMPORTANT: This is a hot path of calling a torchscript function. Try not to 20 // add any code above this. 21 if (execution_plan_) { 22 return *execution_plan_; 23 } 24 auto copy = graph->copy(); 25 runNooptPassPipeline(copy); 26 execution_plan_ = ExecutionPlan(copy, function_name_); 27 28 return *execution_plan_; 29 } 30 getDebugState()31GraphExecutorState SimpleGraphExecutorImpl::getDebugState() { 32 GraphExecutorState state; 33 TORCH_INTERNAL_ASSERT(execution_plan_); 34 state.graph = execution_plan_->graph.get(); 35 auto opt_plan = *execution_plan_; 36 state.execution_plans.emplace(ArgumentSpec{0, 0}, opt_plan); 37 return state; 38 } 39 40 } // namespace torch::jit 41