xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/jit_trace.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 #include <ATen/ATen.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/core/symbol.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 #include <torch/csrc/jit/passes/freeze_module.h>
10 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
11 #include <torch/csrc/jit/passes/inliner.h>
12 #include <torch/csrc/jit/passes/insert_guards.h>
13 #include <torch/csrc/jit/passes/remove_mutation.h>
14 #include <torch/csrc/jit/runtime/graph_executor.h>
15 #include <torch/csrc/jit/runtime/interpreter.h>
16 #include <torch/csrc/jit/runtime/jit_trace.h>
17 #include <torch/csrc/jit/runtime/profiling_record.h>
18 #include <unordered_map>
19 
20 namespace torch::jit {
21 
22 namespace {
23 
24 // A helper structure to maintain the mappings
25 // between values from a scripted graph and
26 // a traced graph
27 struct TracingData {
28   std::unordered_map<Value*, Value*> old_to_new_;
29   std::shared_ptr<Graph> traced_graph_ = nullptr;
30 
TracingDatatorch::jit::__anon0d138ff70111::TracingData31   TracingData() {
32     traced_graph_ = std::make_shared<Graph>();
33   }
34 };
35 
36 // create a node in the traced graph that corresponds to `node`
37 // in the scripted graph. Similar to how `cloneNode` works
traceNode(Node * node,TracingData & td,Stack & stack)38 Node* traceNode(Node* node, TracingData& td, Stack& stack) {
39   GRAPH_DEBUG("Tracing node ", getHeader(node));
40   auto* block = td.traced_graph_->block();
41   auto env = [&td](Value* v) { return td.old_to_new_.at(v); };
42 
43   auto new_node = block->appendNode(td.traced_graph_->createClone(node, env));
44   for (size_t i = 0; i < node->outputs().size(); ++i) {
45     auto oo = node->outputs()[i];
46     auto no = new_node->outputs()[i];
47     no->copyMetadata(oo);
48     td.old_to_new_[oo] = no;
49     GRAPH_DEBUG(
50         "Mapping ",
51         oo->debugName(),
52         " to ",
53         no->debugName()); // old to new outputs
54   }
55   return new_node;
56 }
57 
eraseAllOutputs(Node * opt_pn)58 void eraseAllOutputs(Node* opt_pn) {
59   for (auto i = static_cast<int64_t>(opt_pn->outputs().size()) - 1; i >= 0;
60        i--) {
61     opt_pn->eraseOutput(i);
62   }
63 }
64 
65 void insertTracingNodes(Block*, ProfilingRecord*, TracingData&);
66 
67 // The subtlety in `createPropNodeForIfBlock` is that we need to create
68 // a "propagate" node that will propagate the mapping between the outputs
69 // of a then/else block and the outputs in the traced graph onto the outputs
70 // of the if node in the scripted node. Note, if nodes will disappear in the
71 // the traced graph but they are still used in the scripted graph.
createPropNodeForIfBlock(Block * b,Node * n,ProfilingRecord * pr,TracingData & td)72 void createPropNodeForIfBlock(
73     Block* b,
74     Node* n,
75     ProfilingRecord* pr,
76     TracingData& td) {
77   std::vector<Value*> empty_values{};
78   auto opt_pn = pr->createProfileIValueNode(empty_values);
79   eraseAllOutputs(opt_pn);
80   insertTracingNodes(b, pr, td);
81   b->appendNode(opt_pn);
82   std::function<void(Stack&)> optional_profiler =
83       [pr, n, b, &td](Stack& stack) {
84         std::lock_guard<std::mutex> lock(pr->mutex_);
85 
86         // frame_id is unused
87         int64_t frame_id = 0;
88         pop(stack, frame_id);
89 
90         for (size_t i = 0; i < b->outputs().size(); i++) {
91           // propagate a then-block or else-output to an if-output
92           auto nbo = td.old_to_new_.at(b->outputs()[i]);
93           td.old_to_new_[n->outputs()[i]] = nbo;
94           GRAPH_DEBUG(
95               "Map ",
96               td.old_to_new_[n->outputs()[i]]->debugName(),
97               " to ",
98               nbo->debugName());
99         }
100       };
101 
102   // uncomment for debugging
103   // opt_pn->i_(Symbol::attr("propagate"), 1);
104   opt_pn->setCallback(optional_profiler);
105 }
106 
107 // loop counter is implicit in the loop body outputs, we need to make
108 // it explicit so it can used in 2+ iterations
traceLoopCounter(Node * n,ProfilingRecord * pr,TracingData & td)109 void traceLoopCounter(Node* n, ProfilingRecord* pr, TracingData& td) {
110   LoopView lv(n);
111   auto opt_pn = pr->createProfileIValueNode(lv.currentTripCount());
112   eraseAllOutputs(opt_pn);
113   lv.bodyBlock()->prependNode(opt_pn);
114   std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
115     std::lock_guard<std::mutex> lock(pr->mutex_);
116     // frame_id is unused
117     int64_t frame_id = 0;
118     pop(stack, frame_id);
119     int64_t loop_counter = 0;
120     pop(stack, loop_counter);
121     WithInsertPoint wip(td.traced_graph_->block());
122     auto lc = td.traced_graph_->insertConstant(loop_counter);
123     LoopView lv(n);
124     td.old_to_new_[lv.currentTripCount()] = lc;
125   };
126 
127   // uncomment for debugging
128   // opt_pn->i_(Symbol::attr("loop_counter"), 1);
129   opt_pn->setCallback(optional_profiler);
130 }
131 
132 // Similar to how we propagate the mappings for If nodes, we need to propagate
133 // the mappings from the loop body to the beginning of the block in case we
134 // run another iteration and to the outputs of the Loop node, for any logic
135 // downstream that uses the output values of the loop node
traceLoop(Node * n,ProfilingRecord * pr,TracingData & td)136 static void traceLoop(Node* n, ProfilingRecord* pr, TracingData& td) {
137   std::vector<Value*> empty_values{};
138 
139   // this is a propagation node for block inputs (phi values)
140   // these come from either `prim::Loop` inputs or loop body outputs
141   {
142     auto opt_pn = pr->createProfileIValueNode(empty_values);
143     eraseAllOutputs(opt_pn);
144     opt_pn->insertBefore(n);
145     LoopView lv(n);
146     std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
147       std::lock_guard<std::mutex> lock(pr->mutex_);
148 
149       // frame_id is unused
150       int64_t frame_id = 0;
151       pop(stack, frame_id);
152 
153       LoopView lv(n);
154       TORCH_INTERNAL_ASSERT(
155           lv.bodyCarriedInputs().size() == lv.carriedInputs().size());
156       for (size_t i = 0; i < lv.bodyCarriedInputs().size(); i++) {
157         auto bno = td.old_to_new_.at(lv.carriedInputs()[i]);
158         td.old_to_new_[lv.bodyCarriedInputs()[i]] = bno;
159         GRAPH_DEBUG(
160             "Map ",
161             td.old_to_new_[lv.bodyCarriedInputs()[i]]->debugName(),
162             " to ",
163             bno->debugName());
164       }
165     };
166 
167     // uncomment for debugging
168     // opt_pn->i_(Symbol::attr("loop_entry"), 1);
169     opt_pn->setCallback(optional_profiler);
170   }
171 
172   {
173     insertTracingNodes(LoopView(n).bodyBlock(), pr, td);
174     traceLoopCounter(n, pr, td);
175   }
176 
177   // this is a propagation node for loop outputs
178   {
179     auto opt_pn = pr->createProfileIValueNode(empty_values);
180     eraseAllOutputs(opt_pn);
181     LoopView(n).bodyBlock()->appendNode(opt_pn);
182 
183     // opt_pn->i_(Symbol::attr("loop_propagate"), 1);
184 
185     std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
186       std::lock_guard<std::mutex> lock(pr->mutex_);
187 
188       // frame_id is unused
189       int64_t frame_id = 0;
190       pop(stack, frame_id);
191 
192       LoopView lv(n);
193 
194       TORCH_INTERNAL_ASSERT(
195           lv.bodyCarriedOutputs().size() == lv.carriedOutputs().size());
196       for (size_t i = 0; i < lv.bodyCarriedOutputs().size(); i++) {
197         auto bno = td.old_to_new_.at(lv.bodyCarriedOutputs()[i]);
198         td.old_to_new_[lv.carriedOutputs()[i]] = bno;
199         GRAPH_DEBUG(
200             "Map ",
201             td.old_to_new_[lv.bodyCarriedOutputs()[i]]->debugName(),
202             " to ",
203             bno->debugName());
204       }
205     };
206 
207     // uncomment for debugging
208     // opt_pn->i_(Symbol::attr("loop_exit"), 1);
209     opt_pn->setCallback(optional_profiler);
210   }
211 }
212 
213 // walks all the nodes in a block and adds profiled nodes to each node
214 // see the comment for `optional_profiler` below
insertTracingNodes(Block * block,ProfilingRecord * pr,TracingData & td)215 void insertTracingNodes(Block* block, ProfilingRecord* pr, TracingData& td) {
216   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
217     auto n = *it;
218     it++;
219 
220     GRAPH_DEBUG("Inserting trace for ", getHeader(n));
221     if (n->kind() == prim::If) {
222       IfView ifv(n);
223       createPropNodeForIfBlock(ifv.thenBlock(), n, pr, td);
224       createPropNodeForIfBlock(ifv.elseBlock(), n, pr, td);
225       continue;
226     }
227 
228     if (n->kind() == prim::Loop) {
229       traceLoop(n, pr, td);
230       continue;
231     }
232 
233     TORCH_INTERNAL_ASSERT(n->blocks().empty());
234     auto opt_pn = pr->createProfileIValueNode(n->outputs());
235     eraseAllOutputs(opt_pn);
236     opt_pn->insertAfter(n);
237 
238     // we only use the `opt_pn->node()` to trigger the handler
239     // we still capture the actual scripted node `n` we want to trace
240     // we look at its inputs, map them to the inputs in the traced graph
241     // and create a new node with `traceNode`
242     std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
243       std::lock_guard<std::mutex> lock(pr->mutex_);
244 
245       // frame_id is unused
246       int64_t frame_id = 0;
247       pop(stack, frame_id);
248 
249       GRAPH_DEBUG("Tracing ", getHeader(n));
250       auto tracer = traceNode(n, td, stack);
251       auto outputs_size = n->outputs().size();
252       auto iivs = pop(stack, outputs_size);
253       for (size_t j = 0; j < outputs_size; j++) {
254         auto& iiv = iivs[j];
255         if (iiv.isTensor()) {
256           auto t = iiv.toTensor();
257           auto type = t.defined() ? tensorTypeInCurrentExecutionContext(t)
258                                   : TensorType::get();
259           tracer->outputs().at(j)->setType(type);
260         }
261       }
262     };
263 
264     opt_pn->setCallback(optional_profiler);
265   }
266 }
267 } // namespace
268 
269 // To trace graph we create a profile node for every one
270 // in a scripted graph. When a profiled node handler runs
271 // we insert a new traced node in a trace graph
272 // If the profiled node handler is called in a loop
273 // we will have multiple nodes.
274 // We also maintain the mapping between the outputs of traced
275 // nodes and the outputs of the node in the scripted graph.
276 // There are a few subtleties with tracing Ifs and Loops
277 // discussed above
TraceGraph(const std::shared_ptr<Graph> & graph,Stack & stack)278 std::shared_ptr<Graph> TraceGraph(
279     const std::shared_ptr<Graph>& graph,
280     Stack& stack) {
281   TracingData td;
282   GRAPH_DUMP("Before Inline:", graph);
283   Inline(*graph);
284   EliminateDeadCode(graph);
285   GRAPH_DUMP("After Inline:", graph);
286   auto pr = ProfilingRecord::instrumentGraph(graph);
287   for (auto inp : pr->profiled_graph_->inputs()) {
288     auto ni = td.traced_graph_->addInput();
289     ni->copyMetadata(inp);
290     ni->setType(ni->type());
291     td.old_to_new_[inp] = ni;
292   }
293 
294   // Set type of the graph inputs using the inputs from the stack.
295   // This needs to be done before running the interpreter because the stack
296   // will only have the outputs after the run.
297   for (auto i : c10::irange(stack.size())) {
298     if (stack[i].isTensor()) {
299       td.traced_graph_->inputs().at(i)->setType(
300           tensorTypeInCurrentExecutionContext(stack[i].toTensor()));
301     }
302   }
303 
304   ProfilingRecord::removeProfileCounter(pr->profiled_graph_->block());
305   ProfilingRecord::removeProfilingNodes(pr->profiled_graph_->block());
306   insertTracingNodes(pr->profiled_graph_->block(), pr.get(), td);
307   GRAPH_DUMP("Profiling Graph:", pr->profiled_graph_);
308   Code cd(pr->profiled_graph_, "");
309   InterpreterState is{cd};
310   is.run(stack);
311   for (auto out : pr->profiled_graph_->outputs()) {
312     td.traced_graph_->block()->registerOutput(td.old_to_new_.at(out));
313   }
314 
315   GRAPH_DUMP("Traced graph:", td.traced_graph_);
316   return td.traced_graph_;
317 }
318 } // namespace torch::jit
319