1 #pragma once 2 #include <torch/csrc/jit/runtime/graph_executor.h> 3 4 #include <ATen/core/ivalue.h> 5 #include <c10/util/Exception.h> 6 #include <torch/csrc/autograd/grad_mode.h> 7 #include <torch/csrc/jit/frontend/tracer.h> 8 #include <torch/csrc/jit/ir/ir.h> 9 #include <torch/csrc/jit/passes/shape_analysis.h> 10 #include <torch/csrc/jit/resource_guard.h> 11 #include <torch/csrc/jit/runtime/argument_spec.h> 12 #include <torch/csrc/jit/runtime/autodiff.h> 13 #include <torch/csrc/jit/runtime/custom_operator.h> 14 #include <torch/csrc/jit/runtime/interpreter.h> 15 #include <torch/csrc/jit/runtime/profiling_record.h> 16 17 #include <torch/csrc/autograd/edge.h> 18 #include <torch/csrc/autograd/function.h> 19 #include <torch/csrc/jit/frontend/ir_emitter.h> 20 #include <torch/csrc/jit/runtime/logging.h> 21 22 #include <cstdint> 23 #include <iterator> 24 #include <memory> 25 #include <mutex> 26 #include <unordered_map> 27 #include <utility> 28 #include <vector> 29 30 namespace torch::jit { 31 32 void packGradient(const Gradient& gradient, Node* dnode); 33 bool needsGradient(const std::shared_ptr<const Graph>& graph); 34 void runOptimization( 35 std::shared_ptr<Graph>& graph, 36 bool unroll_non_constant_loops = true, 37 bool const_prop_user_classes = true); 38 void runNondiffOptimization( 39 std::shared_ptr<Graph>& graph, 40 bool strict_fuser_check = false); 41 void debugSetAutodiffSubgraphInlining(bool state); 42 bool TORCH_API getAutodiffSubgraphInlining(); 43 44 void debugSetFusionGroupInlining(bool state); 45 bool getFusionGroupInlining(); 46 47 // Tunable parameters for deciding when to create/keep subgraphs of 48 // differentiable code 49 const size_t autodiffSubgraphNodeThreshold = 2; 50 const size_t autodiffSubgraphInlineThreshold = 5; 51 52 // a Graph can be created via tracing, or via a language-based frontend 53 // GraphExecutor runs it. It can run the same graph on many different sizes 54 // and different requires_grad states, and handles specializations for each 55 // situation. GraphExecutor is completely unaware of tracing or module 56 // parameters to keep the tracing concerns separated. 57 struct GraphExecutorImplBase { prepareGraphGraphExecutorImplBase58 static std::shared_ptr<Graph> prepareGraph( 59 const std::shared_ptr<Graph>& graph) { 60 auto copy = graph->copy(); 61 EraseShapeInformation(copy); 62 return copy; 63 } 64 GraphExecutorImplBaseGraphExecutorImplBase65 GraphExecutorImplBase( 66 const std::shared_ptr<Graph>& graph, 67 std::string function_name) 68 : graph(prepareGraph(graph)), 69 function_name_(std::move(function_name)), 70 num_inputs(this->graph->inputs().size()), 71 num_outputs(this->graph->outputs().size()) {} 72 73 // entry point where execution begins 74 void run(Stack& stack); 75 c10::intrusive_ptr<Future> runAsync( 76 Stack& stack, 77 TaskLauncher taskLauncher = at::launch); 78 79 virtual const ExecutionPlan& getPlanFor( 80 Stack& stack, 81 std::optional<size_t> remaining_bailout_depth = std::nullopt) = 0; 82 virtual GraphExecutorState getDebugState() = 0; 83 virtual ~GraphExecutorImplBase() = default; 84 isOptimizedGraphExecutorImplBase85 virtual bool isOptimized() const { 86 return false; 87 } 88 89 protected: 90 friend struct GraphExecutor; 91 92 // The unoptimized starting graph. This field is effectively const, but we 93 // can't make it so because Graph::copy() is not const (and making it const is 94 // not that easy at this point). 95 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 96 std::shared_ptr<Graph> graph; 97 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 98 std::string function_name_; 99 100 // If false, we'll run the graph as we get it, without any optimizations. 101 // Useful for debugging. 102 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 103 const size_t num_inputs; 104 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 105 const size_t num_outputs; 106 107 // GraphExecutors can be accessed from multiple threads, so this thread needs 108 // to be held every time we access the fallback or plan_cache. 109 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 110 std::mutex compile_mutex; 111 }; 112 113 } // namespace torch::jit 114