// torch/csrc/jit/runtime/graph_executor.cpp
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/inplace_check.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/simple_graph_executor_impl.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/logging.h>

#include <cstdint>
#include <iterator>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

C10_DEFINE_bool(
    torch_jit_execution_plan_reuse_code_graph,
    false,
    "Directly reuse the preprocessed graph in the CodeImpl to reduce the memory consumption. This is aggressive memory saving, and please be cautious!");
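// Illustrative note (assumption, not in the original file): C10_DEFINE_bool
// exposes this flag as FLAGS_torch_jit_execution_plan_reuse_code_graph, so it
// can be toggled programmatically before executors are constructed, e.g.
//
//   FLAGS_torch_jit_execution_plan_reuse_code_graph = true;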

namespace torch::jit {

EnableProfilingGuard::EnableProfilingGuard() {
  auto& executor_mode = getExecutorMode();
  old_executor_mode = executor_mode;
  executor_mode = true;
  old_get_optimize = getGraphExecutorOptimize();
  setGraphExecutorOptimize(true);
}

EnableProfilingGuard::~EnableProfilingGuard() {
  getExecutorMode() = old_executor_mode;
  setGraphExecutorOptimize(old_get_optimize);
}
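// Illustrative usage sketch (assumption, not in the original file):
// EnableProfilingGuard is an RAII guard, so enabling the profiling executor
// and graph-executor optimization for a scope looks like
//
//   {
//     EnableProfilingGuard guard;       // force profiling executor + optimization
//     scripted_module.forward(inputs);  // hypothetical scripted call
//   }  // previous executor/optimize settings restored here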

namespace {
c10::AliasAnalysisKind aliasAnalysisInternalSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace

// For debugging it is helpful to be able to force autodiff subgraphs
// to be created, to check their correctness, even when the
// size of the subgraph is too small to be profitable.
thread_local bool autodiff_subgraph_inlining = true;
void debugSetAutodiffSubgraphInlining(bool state) {
  autodiff_subgraph_inlining = state;
}

bool getAutodiffSubgraphInlining() {
  return autodiff_subgraph_inlining;
}

// For debugging it is helpful to be able to force fusion groups
// to be created.
static std::atomic<bool> fusion_group_inlining(true);
void debugSetFusionGroupInlining(bool state) {
  fusion_group_inlining = state;
}

bool getFusionGroupInlining() {
  return fusion_group_inlining;
}

thread_local std::weak_ptr<Graph> last_executed_optimized_graph;
std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
  return last_executed_optimized_graph.lock();
}
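// Illustrative debugging sketch (assumption, not in the original file): after a
// GraphExecutor has run, the graph that was actually executed can be inspected:
//
//   executor.run(stack);
//   if (auto graph = lastExecutedOptimizedGraph()) {
//     GRAPH_DUMP("last executed optimized graph:", graph);
//   }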
namespace {

using tensor_list = std::vector<at::Tensor>;
using Variable = autograd::Variable;
using autograd::variable_list;

struct CaptureList {
  CaptureList(size_t capture_size) {
    capture_types_.reserve(capture_size);
    var_captures_.reserve(capture_size); // var_captures_.size() might be
                                         // greater than capture_size
    ivalue_captures_.reserve(capture_size);
  }

  void captureTensor(const at::Tensor& tensor, bool is_output) {
    var_captures_.emplace_back(Variable(tensor), is_output);
  }

  void capture(const IValue& val, bool is_output) {
    if (val.isTensor()) {
      capture_types_.emplace_back(CAPTURE_TENSOR);
      captureTensor(val.toTensor(), is_output);
    } else if (val.isTensorList()) {
      //  For a TensorList, we have to flatten it to Tensors during saving and
      //  unflatten it back to a TensorList when using it in backward apply().
      //  This avoids any implicit mutation of the TensorList that happens
      //  between forward & backward.
      capture_types_.emplace_back(CAPTURE_LIST);
      auto tensors = val.toTensorList();
      sizes_.push_back(tensors.size());

      for (const auto& tensor : tensors) {
        captureTensor(tensor, is_output);
      }
    } else {
      capture_types_.emplace_back(CAPTURE_IVALUE);
      ivalue_captures_.push_back(val);
    }
  }

  size_t size() const {
    return capture_types_.size();
  }

  void unpack(Stack& stack, const std::shared_ptr<autograd::Node>& saved_for) {
    auto var_capture_it = var_captures_.begin();
    auto ivalue_capture_it = ivalue_captures_.begin();
    auto size_it = sizes_.begin();
    for (Capture capture_type : capture_types_) {
      switch (capture_type) {
        case CAPTURE_TENSOR: {
          stack.emplace_back(var_capture_it->unpack(saved_for));
          ++var_capture_it;
        } break;
        case CAPTURE_LIST: {
          c10::List<at::Tensor> lst;
          auto size = *size_it++;
          for (const auto i : c10::irange(size)) {
            (void)i;
            lst.emplace_back(var_capture_it->unpack(saved_for));
            var_capture_it++;
          }
          stack.emplace_back(std::move(lst));
        } break;
        case CAPTURE_IVALUE: {
          stack.push_back(*ivalue_capture_it++);
        } break;
      }
    }
  }

  void release_variables() {
    for (auto& var_capture_ : var_captures_) {
      var_capture_.reset_data();
    }
  }

 private:
  enum Capture : uint8_t {
    CAPTURE_TENSOR,
    CAPTURE_LIST,
    CAPTURE_IVALUE,
  };

  std::vector<Capture> capture_types_;
  std::vector<autograd::SavedVariable> var_captures_;
  std::vector<IValue> ivalue_captures_;
  std::vector<size_t> sizes_;
};
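// Illustrative example of the flattening above (assumption, not in the
// original file): capturing IValue(TensorList{a, b}) followed by IValue(42)
// records
//   capture_types_   = {CAPTURE_LIST, CAPTURE_IVALUE}
//   sizes_           = {2}
//   var_captures_    = {a, b}   (as SavedVariables)
//   ivalue_captures_ = {42}
// and unpack() later rebuilds a c10::List<at::Tensor>{a, b} plus the 42 on the
// interpreter stack, in the original capture order.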

// Describes how to turn a flattened list of tensors back into the IValues that
// DifferentiableGraphBackward expects.
struct UnpackInstructions {
  UnpackInstructions(size_t num_inputs) {
    insts_.reserve(num_inputs);
  }
  void pushTensor() {
    insts_.emplace_back(PUSH_TENSOR);
  }
  void pushNone() {
    insts_.emplace_back(PUSH_NONE);
  }
  void pushTensorList(size_t size) {
    insts_.emplace_back(PUSH_LIST);
    sizes_.push_back(size);
  }
  void unpack(variable_list&& inputs, Stack& stack) {
    auto input_it = std::make_move_iterator(inputs.begin());
    auto sizes_it = sizes_.begin();
    for (Inst inst : insts_) {
      switch (inst) {
        case PUSH_TENSOR: {
          at::Tensor t = *input_it++;
          stack.emplace_back(std::move(t));
        } break;
        case PUSH_LIST: {
          std::vector<at::Tensor> lst(input_it, input_it + *sizes_it++);
          stack.emplace_back(lst);
        } break;
        case PUSH_NONE: {
          stack.emplace_back();
        }
      }
    }
  }

 private:
  enum Inst : uint8_t {
    PUSH_TENSOR,
    PUSH_LIST, // consumes one size
    PUSH_NONE,
  };
  std::vector<Inst> insts_;
  std::vector<size_t> sizes_;
};
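// Illustrative example (assumption, not in the original file): if the backward
// graph inputs were recorded as (Tensor, Tensor[] of size 2, None), the
// instruction stream is {PUSH_TENSOR, PUSH_LIST, PUSH_NONE} with sizes_ = {2}.
// unpack() then pushes one incoming grad as a Tensor, regroups the next two
// grads into a single TensorList IValue, and pushes a None IValue (which
// consumes no grad) onto the stack.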

// unpack values packed by `packReturnValuesIntoTuple`
static void unpackReturnTuple(Stack& stack) {
  auto tuple = pop(stack).toTuple();
  stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
}

struct DifferentiableGraphBackward : public autograd::Node {
  DifferentiableGraphBackward(
      GraphExecutor executor,
      size_t input_size,
      size_t capture_size)
      : executor(std::move(executor)),
        captures_(capture_size),
        input_instructions_(input_size) {}

  variable_list apply(variable_list&& inputs) override {
    Stack stack;
    stack.reserve(captures_.size() + inputs.size());

    input_instructions_.unpack(std::move(inputs), stack);
    captures_.unpack(stack, shared_from_this());
    GRAPH_DEBUG("Running DifferentiableGraphBackward for ", &executor);
    executor.run(stack);
    unpackReturnTuple(stack);

    // NB: stack.size() == num_outputs() is not always true
    // after we added TensorList support.
    // Example: aten::stack(Tensor[] tensors, int) where
    // tensors = [x, x].
    // Here stack.size() is 1, holding a TensorList IValue produced as the
    // backward graph output.
    // num_outputs(), however, is 2: it is the number of outputs of
    // grad_fn (an autograd::Node). grad_fn's outputs are
    // grads with regard to the Tensor/Variable `x`, not to the
    // graph input TensorList [x, x]. These two grads will
    // be accumulated into x.grad later using autograd::InputBuffer.
    variable_list outputs;
    outputs.reserve(num_outputs());
    size_t output_index = 0;
    for (IValue& v : stack) {
      if (v.isTensorList()) {
        for (at::Tensor tensor : v.toTensorList()) {
          produceOutput(output_index++, std::move(tensor), outputs);
        }
      } else if (v.isTensor()) {
        if (!v.toTensor().defined()) {
          // this undefined gradient actually corresponds to a tensor list
          if (input_tensor_lists_.count(output_index) != 0) {
            size_t list_size = input_tensor_lists_[output_index];
            for (size_t i = 0; i < list_size; i++) {
              produceOutput(output_index++, {}, outputs);
            }
          } else {
            produceOutput(output_index++, {}, outputs);
          }
        } else {
          produceOutput(output_index++, std::move(v).toTensor(), outputs);
        }
      } else {
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(v.isNone());
        output_index++;
        // Input grad can also be None even if it requires grad
        // Example: `other` in expand_as(self, other)
        outputs.emplace_back();
      }
    }
    TORCH_INTERNAL_ASSERT(
        num_outputs() == outputs.size(),
        "DifferentiableGraphBackward: expected ",
        num_outputs(),
        " outputs but found ",
        outputs.size());
    return outputs;
  }

  void capture(const IValue& val, bool is_output) {
    captures_.capture(val, is_output);
  }

  void addOutputForTensor(const at::Tensor& tensor) {
    auto v = Variable(tensor);
    add_next_edge(
        v.defined() ? torch::autograd::impl::gradient_edge(v)
                    : autograd::Edge{});
  }
  void addOutputForIValue(const IValue& value) {
    if (value.isTensorList()) {
      input_tensor_lists_.insert({index_, value.toTensorList().size()});
      for (const at::Tensor& tensor : value.toTensorList()) {
        addOutputForTensor(tensor);
        index_++;
      }
    } else if (value.isTensor()) {
      addOutputForTensor(value.toTensor());
      index_++;
    } else {
      // We could have None passed here via `Optional[Tensor]`
      add_next_edge(autograd::Edge{});
      index_++;
    }
  }

  void addInputVariable(Variable output) {
    // NB: since our requires_grad setting is only a heuristic we might end
    // up wanting to differentiate through integral tensors, which is
    // generally a hard error in autograd.
    if (at::isFloatingType(output.scalar_type()) ||
        at::isComplexType(output.scalar_type())) {
      autograd::create_gradient_edge(output, shared_from_this());
      output.set_requires_grad(true);
    } else {
      add_input_metadata(autograd::Node::undefined_input{});
    }
  }

  void addInputIValue(const IValue& v) {
    if (v.isTensorList()) {
      auto tensors = v.toTensorList();
      input_instructions_.pushTensorList(tensors.size());
      for (const at::Tensor& tensor : tensors) {
        addInputVariable(tensor);
      }
    } else if (v.isTensor()) {
      input_instructions_.pushTensor();
      addInputVariable(v.toTensor());
    } else if (v.isNone()) {
      input_instructions_.pushNone();
      addInputVariable(Variable{});
    }
  }

  void release_variables() override {
    captures_.release_variables();
  }

 private:
  void produceOutput(size_t i, at::Tensor output, variable_list& outputs) {
    if (task_should_compute_output(i)) {
      const auto& edge = next_edge(i);
      if (output.defined()) {
        outputs.emplace_back(std::move(output));
      } else if (edge.is_valid()) {
        outputs.emplace_back(
            edge.function->input_metadata(edge.input_nr).zeros_like());
      } else {
        outputs.emplace_back();
      }
    } else {
      outputs.emplace_back();
    }
  }

  friend struct ExecutionPlan;
  GraphExecutor executor;
  CaptureList captures_;
  UnpackInstructions input_instructions_;
  // We need to track input lists to the forward graph, since in backward
  // graphs these become a single undefined tensor when the gradients are
  // zero; in that case we need to convert the undefined tensor back into a
  // list of undefined tensors.
  // TODO: switch to using UnpackInstructions
  size_t index_ = 0;
  std::map<size_t, size_t> input_tensor_lists_;
};

// An optimized way of executing the subgraph computed directly on
// tensors rather than Variables.
// This will unwrap Variables, run the plan, and re-wrap them.
// It can optionally also have a gradient which is hooked up
// to the output Variables if present.
struct DifferentiableGraphOp {
  DifferentiableGraphOp(Gradient grad)
      : f_ptr(std::make_shared<GraphExecutor>(grad.f, "<forward op>")),
        legacy_f(grad.f, "<forward op>"),
        grad(std::move(grad)),
        grad_executor(this->grad.df, "<backward op>"),
        num_inputs(this->grad.f->inputs().size()),
        num_outputs(this->grad.f->outputs().size()) {}

  // XXX: keep in mind that stack can be larger than the inputs we need!
  void operator()(Stack& stack) const {
    auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
        grad_executor,
        grad.df_input_vjps.size(),
        grad.df_input_captured_inputs.size() +
            grad.df_input_captured_outputs.size());

    {
      auto inputs = last(stack, num_inputs);
      // hook up the outputs of df to the gradient functions of the inputs that
      // require gradients
      for (auto idx : grad.df_output_vjps) {
        grad_fn->addOutputForIValue(inputs[idx]);
      }
      captureInputs(*grad_fn, inputs);
    }

    detachVariables(stack);
    if (IsNewExecutorEnabled()) {
      const ExecutionPlan& plan = f_ptr->getPlanFor(stack);
      InterpreterState(plan.code).run(stack);
    } else {
      InterpreterState(legacy_f).run(stack);
    }

    {
      auto outputs = last(stack, num_outputs);
      // hook up the gradients for the output tensors that require gradients
      // to the inputs to our gradient function df
      // TODO - XXX - if any output is the same tensor multiple times, views
      // have to be setup here. We need to refactor autograd until it is safe
      // for tensors to be constructed without all the viewing infrastructure.
      // this is currently intentionally not done here so we can get an idea of
      // our perf before introducing overhead for correctness
      for (auto idx : grad.df_input_vjps) {
        grad_fn->addInputIValue(outputs[idx]);
      }
      captureOutputs(*grad_fn, outputs);
      // drop the temporary outputs so that we return the same number of
      // outputs as if we were not also calculating gradient
      const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs;
      stack.erase(stack.end() - num_temporary_outputs, stack.end());
    }
  }

 private:
  friend GraphExecutor* detail::getGradExecutor(Operation& op);
  friend GraphExecutor* detail::getDifferentiableGraphOpExecutor(Operation& op);

  at::Tensor detach(at::Tensor t) const {
    if (!t.defined()) {
      return t;
    }
    return t.detach();
  }

  void detach(IValue& v) const {
    if (v.isTensor()) {
      v = IValue(detach(std::move(v).toTensor()));
    } else if (v.isTensorList()) {
      std::vector<at::Tensor> lst = v.toTensorVector();
      for (auto& tensor : lst) {
        tensor = detach(tensor);
      }
      v = std::move(lst);
    }
  }

  void detachVariables(Stack& stack) const {
    // It would be nice to use an ArrayRef here, but unfortunately those can
    // only return const references, so we need to do a bunch of indexing
    // ourselves.
    const int64_t stack_size = stack.size();
    const int64_t stack_offset = stack_size - num_inputs;
    for (const auto i : c10::irange(stack_offset, stack_size)) {
      detach(stack[i]);
    }
  }
  // Capture (save) inputs that would be required to subsequently run backwards
  void captureInputs(
      DifferentiableGraphBackward& grad_fn,
      at::ArrayRef<IValue> inputs) const {
    for (size_t offset : grad.df_input_captured_inputs) {
      grad_fn.capture(inputs[offset], /*is_output*/ false);
    }
  }
  void captureOutputs(
      DifferentiableGraphBackward& grad_fn,
      at::ArrayRef<IValue> outputs) const {
    for (size_t offset : grad.df_input_captured_outputs) {
      grad_fn.capture(outputs[offset], /*is_output*/ true);
    }
  }

  std::shared_ptr<GraphExecutor> f_ptr;
  Code legacy_f;
  Gradient grad;
  GraphExecutor grad_executor;

  const size_t num_inputs;
  const size_t num_outputs;
};
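// Illustrative walk-through of operator() above (assumption, not in the
// original file): for a differentiable subgraph with f_real_outputs == 1 and
// one extra temporary output captured for backward, the interpreter stack
// evolves roughly as
//   [.., in0, in1]     before operator()
//   [.., out0, tmp0]   after running the forward plan f
//   [.., out0]         after erasing num_outputs - f_real_outputs temporaries
// while grad_fn captures the inputs/outputs listed in
// df_input_captured_inputs / df_input_captured_outputs for later use in df.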

Gradient getGradient(const Node* n) {
  AT_ASSERT(n->kind() == prim::DifferentiableGraph);
  Gradient grad;
  grad.f = n->g(attr::Subgraph);
  grad.df = n->g(attr::ReverseSubgraph);
  grad.f_real_outputs = n->i(attr::f_real_outputs);
  grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
  grad.df_input_captured_inputs =
      fmap<size_t>(n->is(attr::df_input_captured_inputs));
  grad.df_input_captured_outputs =
      fmap<size_t>(n->is(attr::df_input_captured_outputs));
  grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
  return grad;
}
} // anonymous namespace

RegisterOperators reg_graph_executor_ops({Operator(
    prim::DifferentiableGraph,
    [](const Node* n) -> Operation {
      return DifferentiableGraphOp(getGradient(n));
    },
    aliasAnalysisInternalSpecialCase())});

namespace detail {

GraphExecutor* getGradExecutor(Operation& op) {
  if (auto diff_op = op.target<DifferentiableGraphOp>()) {
    return &diff_op->grad_executor;
  }
  return nullptr;
}

GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op) {
  TORCH_INTERNAL_ASSERT(
      IsNewExecutorEnabled(),
      __FUNCTION__,
      " is only accessible under profiling executor\n");
  if (auto diff_op = op.target<DifferentiableGraphOp>()) {
    return diff_op->f_ptr.get();
  }
  return nullptr;
}
} // namespace detail

void GraphExecutorImplBase::run(Stack& stack) {
  TORCH_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  C10_LOG_API_USAGE_ONCE("torch.graph_executor.run");
  logging::getLogger()->addStatValue(
      logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

  const ExecutionPlan& plan = getPlanFor(stack);
  InterpreterState(plan.code).run(stack);
  last_executed_optimized_graph = plan.graph;
}

c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  TORCH_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  C10_LOG_API_USAGE_ONCE("torch.graph_executor.runAsync");
  logging::getLogger()->addStatValue(
      logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

  struct Frame {
    explicit Frame(ExecutionPlan eplan, TaskLauncher taskLauncher)
        : plan(std::move(eplan)), state(plan.code, std::move(taskLauncher)) {}
    ExecutionPlan plan;
    InterpreterState state;
  };
  auto frame =
      std::make_shared<Frame>(getPlanFor(stack), std::move(taskLauncher));
  auto res = frame->state.runAsync(stack);
  last_executed_optimized_graph = frame->plan.graph;
  if (!res->completed()) {
    // If not completed, persist the Frame until complete.
    res->addCallback([frame](Future& /* unused */) {});
  }
  return res;
}
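// Illustrative sketch (assumption, not in the original file): runAsync lets the
// caller supply the task launcher that the interpreter uses for asynchronous
// work, e.g. dispatching onto ATen's inter-op thread pool:
//
//   auto future = executor.runAsync(stack, [](std::function<void()> work) {
//     at::launch(std::move(work));  // any thread pool of the caller's choice
//   });
//   future->wait();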

// A Graph can be created via tracing, or via a language-based frontend;
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each
// situation. GraphExecutor is completely unaware of tracing or module
// parameters to keep the tracing concerns separated.
struct GraphExecutorImpl : public GraphExecutorImplBase {
  GraphExecutorImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name)
      : GraphExecutorImplBase(graph, std::move(function_name)),
        arg_spec_creator_(*graph) {
    logging::getLogger()->addStatValue(
        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
  }

  const ExecutionPlan& getPlanFor(
      Stack& stack,
      std::optional<size_t> remaining_bailout_depth) override {
    return getGraphExecutorOptimize() ? getOrCompile(stack)
                                      : getOrCompileFallback();
  }

  GraphExecutorState getDebugState() override {
    GraphExecutorState state;
    state.graph = graph.get();
    if (fallback) {
      state.fallback = fallback;
    }
    for (auto& entry : plan_cache) {
      state.execution_plans.emplace(entry.first, entry.second);
    }
    return state;
  }

 protected:
  friend struct GraphExecutor;

  const ExecutionPlan& getOrCompileFallback() {
    std::lock_guard<std::mutex> lock(compile_mutex);
    if (!fallback) {
      auto graph_ = graph->copy();
      runRequiredPasses(graph_);
      fallback = ExecutionPlan(graph_, function_name_);
    }
    return fallback;
  }

  const ExecutionPlan& getOrCompile(const Stack& stack) {
    // Outside the lock guard, to minimize the time holding the lock on the
    // fast path. ArgumentSpec even computes its hashCode here.
    ArgumentSpec spec =
        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
    {
      std::lock_guard<std::mutex> lock(compile_mutex);
      auto it = plan_cache.find(spec);
      if (it != plan_cache.end()) {
        logging::getLogger()->addStatValue(
            logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
        return it->second;
      }
      auto plan = compileSpec(spec);
      auto r = plan_cache.emplace(std::move(spec), std::move(plan));
      logging::getLogger()->addStatValue(
          logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
      return r.first->second;
    }
  }

  ExecutionPlan compileSpec(const ArgumentSpec& spec) {
    auto opt_graph = graph->copy();
    GRAPH_DUMP("Optimizing the following function:", opt_graph);
    arg_spec_creator_.specializeTypes(*opt_graph, spec);

    // Phase 0. Inline functions, then clean up any artifacts that the inliner
    //          left in that may inhibit optimization
    Inline(*opt_graph);
    GRAPH_DEBUG("After Inline, before LowerGradOf\n", *opt_graph);
    LowerGradOf(*opt_graph);
    GRAPH_DEBUG(
        "After LowerGradOf, before specializeAutogradZero\n", *opt_graph);
    specializeAutogradZero(opt_graph);
    GRAPH_DEBUG(
        "After specializeAutogradZero, before LowerSimpleTuples\n", *opt_graph);
    LowerSimpleTuples(opt_graph);
    GRAPH_DEBUG(
        "After LowerSimpleTuples, before ConstantPooling\n", *opt_graph);
    ConstantPooling(opt_graph);
    GRAPH_DEBUG(
        "After ConstantPooling, before runRequiredPasses\n", *opt_graph);

    // Phase 1. Specialize to input definedness (this is very important for
    //          gradient graphs), and run required passes to bring the graph
    //          to an executable form.
    runRequiredPasses(opt_graph);
    GRAPH_DEBUG(
        "After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

    // Phase 2. Propagate detailed information about the spec through the
    //          graph (this enables more specializations in later passes).
    //          Shape propagation sometimes depends on certain arguments being
    //          constants, and constant propagation doesn't need shape
    //          information anyway, so it's better to run it first.
    ConstantPropagation(opt_graph);
    GRAPH_DEBUG(
        "After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
    PropagateInputShapes(opt_graph);
    GRAPH_DEBUG(
        "After PropagateInputShapes, before PropagateRequiresGrad\n",
        *opt_graph);
    PropagateRequiresGrad(opt_graph);
    GRAPH_DEBUG(
        "After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

    // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
    //          that we can still execute using autograd).
    runOptimization(opt_graph);

    // Phase 4. If this graph will be differentiated, we need to slice out the
    //          symbolically differentiable subgraphs for further optimizations.
    // Phase 5. Apply non-differentiable optimizations to the graphs we've found
    //          (or the whole graph if we know we won't need its derivative).
    if (needsGradient(opt_graph)) {
      auto diff_nodes = CreateAutodiffSubgraphs(
          opt_graph,
          autodiff_subgraph_inlining ? autodiffSubgraphNodeThreshold : 1);
      GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *opt_graph);
      size_t idx = 0;
      for (Node* dnode : diff_nodes) {
        GRAPH_DEBUG("Optimizing diff node ", idx);
        auto diff_graph = std::move(dnode->g(attr::Subgraph));
        Gradient gradient = differentiate(diff_graph);
        GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
        GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
        // Run post-differentiation optimizations. Autodiff will replace some
        // parts of the graph with new graphs; these new graphs usually consist
        // of control flow and lack shape information on nodes, so we run shape
        // propagation and differentiable optimizations to ensure the graph is
        // optimized.
        PropagateInputShapes(gradient.f);
        GRAPH_DEBUG("After PropagateInputShapes\n", *(gradient.f));
        runOptimization(gradient.f);
        // run non-diff optimization on the forward graph
        runNondiffOptimization(gradient.f);
        packGradient(gradient, dnode);
        GRAPH_DEBUG("Finished optimizing diff node ", idx++);
      }
      InlineAutodiffSubgraphs(
          opt_graph,
          autodiff_subgraph_inlining ? autodiffSubgraphInlineThreshold : 1);
      GRAPH_DEBUG("After InlineAutodiffSubgraphs\n", *opt_graph);
    } else {
      runNondiffOptimization(opt_graph);
    }
    // Make sure there are no leftovers from any passes.
    EliminateDeadCode(opt_graph);
    GRAPH_DUMP("After compileSpec optimizations:", opt_graph);
    return ExecutionPlan(opt_graph, function_name_);
  }

  ~GraphExecutorImpl() override = default;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ArgumentSpecCreator arg_spec_creator_;
  // Populated only when optimize is false (and in that case plan_cache will be
  // unused). The compiled version of graph.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ExecutionPlan fallback;

  // Mapping from argument configurations to optimized versions of the graph
  // that are specialized to the spec.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
};

GraphExecutor::GraphExecutor(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : pImpl(
          IsNewExecutorEnabled()
              ? (getProfilingMode() ?

                                    dynamic_cast<GraphExecutorImplBase*>(
                                        new ProfilingGraphExecutorImpl(
                                            graph,
                                            std::move(function_name)))
                                    : dynamic_cast<GraphExecutorImplBase*>(
                                          new SimpleGraphExecutorImpl(
                                              graph,
                                              std::move(function_name))))
              : dynamic_cast<GraphExecutorImplBase*>(
                    new GraphExecutorImpl(graph, std::move(function_name)))) {}

GraphExecutor::GraphExecutor(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    ExecutorExecutionMode executor_mode)
    : pImpl(
          executor_mode == ExecutorExecutionMode::SIMPLE
              ? dynamic_cast<GraphExecutorImplBase*>(
                    new SimpleGraphExecutorImpl(
                        graph,
                        std::move(function_name)))
              : dynamic_cast<GraphExecutorImplBase*>(
                    new ProfilingGraphExecutorImpl(
                        graph,
                        std::move(function_name)))) {}

void GraphExecutor::run(Stack& inputs) {
  return pImpl->run(inputs);
}

c10::intrusive_ptr<Future> GraphExecutor::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return pImpl->runAsync(stack, std::move(taskLauncher));
}

const ExecutionPlan& GraphExecutor::getPlanFor(
    Stack& inputs,
    std::optional<size_t> remaining_bailout_depth) {
  return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}

GraphExecutorState GraphExecutor::getDebugState() {
  return pImpl->getDebugState();
}

void GraphExecutor::debugFlushCompilationCache() {
  if (auto ppImpl =
          std::dynamic_pointer_cast<ProfilingGraphExecutorImpl>(pImpl)) {
    ppImpl->debugFlushCompilationCache();
  } else {
    // we are deprecating legacy executor
    TORCH_INTERNAL_ASSERT(false, "Not Implemented for Legacy Executor");
  }
}

bool GraphExecutor::isOptimized() const {
  return pImpl && pImpl->isOptimized();
}

TORCH_API bool IsNewExecutorEnabled() {
  static const auto disable_new_executor =
      std::getenv("TORCH_JIT_DISABLE_NEW_EXECUTOR");
  return getExecutorMode() && FLAGS_torch_jit_enable_new_executor &&
      !disable_new_executor;
}
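// Illustrative note (assumption, not in the original file): because the
// std::getenv() result above is cached in a function-local static, setting
//   export TORCH_JIT_DISABLE_NEW_EXECUTOR=1
// before the process starts forces the legacy GraphExecutorImpl path for the
// lifetime of the process, regardless of getExecutorMode().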

void runRequiredPasses(const std::shared_ptr<Graph>& g) {
  // Implicitly inserted expand nodes are not necessarily always valid when
  // used inside script methods that might have unstable shapes. We remove the
  // implicitly created ones, and have shape analysis add valid expand nodes
  // when the shapes are stable.
  RemoveExpands(g);
  CanonicalizeOps(g);
  EliminateDeadCode(g);
}

void packGradient(const Gradient& gradient, Node* dnode) {
  AT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
  dnode->g_(attr::Subgraph, gradient.f)
      ->g_(attr::ReverseSubgraph, gradient.df)
      ->i_(attr::f_real_outputs, gradient.f_real_outputs)
      ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
      ->is_(
          attr::df_input_captured_inputs,
          fmap<int64_t>(gradient.df_input_captured_inputs))
      ->is_(
          attr::df_input_captured_outputs,
          fmap<int64_t>(gradient.df_input_captured_outputs))
      ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
}
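// Note (added commentary, not in the original file): packGradient is the
// inverse of getGradient earlier in this file -- the node attributes written
// here (Subgraph, ReverseSubgraph, f_real_outputs, df_input_vjps,
// df_input_captured_inputs, df_input_captured_outputs, df_output_vjps) are
// exactly the ones getGradient reads back when a prim::DifferentiableGraph
// node is instantiated as a DifferentiableGraphOp.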

static bool mayIntroduceGradient(const Block* b) {
  for (const Node* n : b->nodes()) {
    if (n->kind() == prim::PythonOp)
      return true;
    for (const Block* bb : n->blocks()) {
      if (mayIntroduceGradient(bb))
        return true;
    }
  }
  return false;
}

bool needsGradient(const std::shared_ptr<const Graph>& graph) {
  if (!autograd::GradMode::is_enabled()) {
    return false;
  }

  if (mayIntroduceGradient(graph->block())) {
    return true;
  }

  for (const Value* input : graph->inputs()) {
    if (input->type()->requires_grad()) {
      return true;
    }
  }

  return false;
}

void runNondiffOptimization(
    std::shared_ptr<Graph>& graph,
    bool strict_fuser_check) {
  GRAPH_DEBUG(
      "Before customPrePasses (beginning of runNondiffOptimization)\n", *graph);
  // Run custom passes that different backends can register.
  for (const auto& passPair : getCustomPrePasses()) {
    passPair.first(graph);
  }
  GRAPH_DEBUG("After customPrePasses\n", *graph);

  // Decomposition pass: decompose certain ops that will be used in the
  // following passes (like BatchMM and JIT fusion).
  DecomposeOps(graph);
  GRAPH_DEBUG("After DecomposeOps\n", *graph);

  // TupleConstruct / TupleUnpack pairs can still be present at this point
  // and must be removed for fusion.
  LowerSimpleTuples(graph);
  GRAPH_DEBUG("After LowerSimpleTuples, before BatchMM\n", *graph);

  // Rewrite subgraphs with many MMs into expressions that batch them.
  BatchMM(graph);

  GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
  if (getExecutorMode()) {
    if (tensorExprFuserEnabled()) {
      auto min_size = getFusionGroupInlining() ? 2 : 1;
      auto dyn_shapes = tensorExprDynamicShapeFusionEnabled();
      FuseTensorExprs(graph, min_size, /*composed_op*/ false, dyn_shapes);
    }
  } else {
    FuseGraph(graph, strict_fuser_check);
  }
  GRAPH_DEBUG("After Fusion\n", *graph);

  // Run custom post-fusion passes
  for (const auto& passPair : getCustomPostPasses()) {
    passPair.first(graph);
  }
  GRAPH_DEBUG(
      "After customPostPasses (end of runNondiffOptimization)\n", *graph);
}
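// Illustrative sketch (assumption, not in the original file): backends feed the
// getCustomPrePasses()/getCustomPostPasses() loops above through the pass
// manager (torch/csrc/jit/passes/pass_manager.h), roughly like
//
//   static RegisterPass my_backend_pass([](std::shared_ptr<Graph>& g) {
//     // backend-specific rewrite of g goes here
//   });
//
// The exact registration helper (pre- vs. post-fusion) depends on the
// pass_manager API; this only shows where custom passes plug in.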

void runOptimization(
    std::shared_ptr<Graph>& graph,
    bool unroll_non_constant_loops,
    bool const_prop_user_classes) {
  // Basic graph preprocessing to eliminate noise.
  GRAPH_DEBUG(
      "Before EliminateDeadCode (beginning of runOptimization)\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);
  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before PeepholeOptimize\n", *graph);

  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);

  if (const_prop_user_classes) {
    ConstantPropagation(graph);
  } else {
    ConstantPropagation(graph, true);
  }
  GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);

  ConstantPooling(graph);
  GRAPH_DEBUG("After ConstantPooling\n", *graph);

  // Unroll small loops, and eliminate expressions that are the same at every
  // iteration.
  bool unroll_success = false;
  if (unroll_non_constant_loops) {
    unroll_success = UnrollLoops(graph);
    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
  } else {
    unroll_success = UnrollConstantLoops(graph);
    GRAPH_DEBUG(
        "After UnrollConstantLoops, before RemoveListMutation\n", *graph);
  }

  if (unroll_success) {
    // run again with unrolled loops
    RemoveListMutation(graph);
    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG("After ConstantPropagation\n", *graph);
  }

  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
  CheckInplace(graph);
  GRAPH_DEBUG("After CheckInplace (end of runOptimization)\n", *graph);
}

Node* replaceBlockWithFallbackGraph(Block* b, ArrayRef<Value*> inputs) {
  auto graph = std::make_shared<Graph>();

  // We are either copying the block inside an If or prim::Loop, or we are
  // copying the whole graph. We need to differentiate between the two cases
  // because cloneFrom automatically adds inputs if we are copying a graph's
  // block, whereas otherwise we need the inputs from the caller.
  if (b->owningNode() != nullptr) {
    std::unordered_map<Value*, Value*> input_mapping;
    auto value_map = [&input_mapping](Value* v) { return input_mapping[v]; };
    for (auto inp : inputs) {
      input_mapping[inp] = graph->block()->addInput();
    }
    graph->block()->cloneFrom(b, value_map);
  } else {
    auto value_map = [](Value* v) { return v; };
    graph->block()->cloneFrom(b, value_map);
  }

  auto fallback = b->owningGraph()->create(
      prim::FallbackGraph, inputs, b->outputs().size());
  fallback->g_(attr::Subgraph, graph);
  b->prependNode(fallback);

  for (const auto i : c10::irange(inputs.size())) {
    graph->inputs()[i]->setType(inputs[i]->type());
    graph->inputs()[i]->copyMetadata(inputs[i]);
  }

  for (const auto i : c10::irange(b->outputs().size())) {
    fallback->output(i)->setType(b->outputs()[i]->type());
    fallback->output(i)->copyMetadata(b->outputs()[i]);
    b->replaceOutput(i, fallback->output(i));
  }

  ProfilingRecord::removeProfilingNodes(graph->block());

  for (auto it = b->nodes().rbegin(); it != fallback->iterator(); it++) {
    it.destroyCurrent();
  }

  return fallback;
}
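// Illustrative note (assumption, not in the original file): the net effect of
// replaceBlockWithFallbackGraph is that the block's original nodes are
// replaced by a single prim::FallbackGraph node whose Subgraph attribute holds
// an unprofiled copy of the block; the block's outputs are rewired to that
// node's outputs, and the now-dead original nodes are destroyed. The returned
// node is the newly created prim::FallbackGraph.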

} // namespace torch::jit