1
2 #include <ATen/ATen.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/core/symbol.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 #include <torch/csrc/jit/passes/freeze_module.h>
10 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
11 #include <torch/csrc/jit/passes/inliner.h>
12 #include <torch/csrc/jit/passes/insert_guards.h>
13 #include <torch/csrc/jit/passes/remove_mutation.h>
14 #include <torch/csrc/jit/runtime/graph_executor.h>
15 #include <torch/csrc/jit/runtime/interpreter.h>
16 #include <torch/csrc/jit/runtime/jit_trace.h>
17 #include <torch/csrc/jit/runtime/profiling_record.h>
18 #include <unordered_map>
19
20 namespace torch::jit {
21
22 namespace {
23
24 // A helper structure to maintain the mappings
25 // between values from a scripted graph and
26 // a traced graph
27 struct TracingData {
28 std::unordered_map<Value*, Value*> old_to_new_;
29 std::shared_ptr<Graph> traced_graph_ = nullptr;
30
TracingDatatorch::jit::__anon0d138ff70111::TracingData31 TracingData() {
32 traced_graph_ = std::make_shared<Graph>();
33 }
34 };
35
36 // create a node in the traced graph that corresponds to `node`
37 // in the scripted graph. Similar to how `cloneNode` works
traceNode(Node * node,TracingData & td,Stack & stack)38 Node* traceNode(Node* node, TracingData& td, Stack& stack) {
39 GRAPH_DEBUG("Tracing node ", getHeader(node));
40 auto* block = td.traced_graph_->block();
41 auto env = [&td](Value* v) { return td.old_to_new_.at(v); };
42
43 auto new_node = block->appendNode(td.traced_graph_->createClone(node, env));
44 for (size_t i = 0; i < node->outputs().size(); ++i) {
45 auto oo = node->outputs()[i];
46 auto no = new_node->outputs()[i];
47 no->copyMetadata(oo);
48 td.old_to_new_[oo] = no;
49 GRAPH_DEBUG(
50 "Mapping ",
51 oo->debugName(),
52 " to ",
53 no->debugName()); // old to new outputs
54 }
55 return new_node;
56 }
57
eraseAllOutputs(Node * opt_pn)58 void eraseAllOutputs(Node* opt_pn) {
59 for (auto i = static_cast<int64_t>(opt_pn->outputs().size()) - 1; i >= 0;
60 i--) {
61 opt_pn->eraseOutput(i);
62 }
63 }
64
65 void insertTracingNodes(Block*, ProfilingRecord*, TracingData&);
66
67 // The subtlety in `createPropNodeForIfBlock` is that we need to create
68 // a "propagate" node that will propagate the mapping between the outputs
69 // of a then/else block and the outputs in the traced graph onto the outputs
70 // of the if node in the scripted node. Note, if nodes will disappear in the
71 // the traced graph but they are still used in the scripted graph.
createPropNodeForIfBlock(Block * b,Node * n,ProfilingRecord * pr,TracingData & td)72 void createPropNodeForIfBlock(
73 Block* b,
74 Node* n,
75 ProfilingRecord* pr,
76 TracingData& td) {
77 std::vector<Value*> empty_values{};
78 auto opt_pn = pr->createProfileIValueNode(empty_values);
79 eraseAllOutputs(opt_pn);
80 insertTracingNodes(b, pr, td);
81 b->appendNode(opt_pn);
82 std::function<void(Stack&)> optional_profiler =
83 [pr, n, b, &td](Stack& stack) {
84 std::lock_guard<std::mutex> lock(pr->mutex_);
85
86 // frame_id is unused
87 int64_t frame_id = 0;
88 pop(stack, frame_id);
89
90 for (size_t i = 0; i < b->outputs().size(); i++) {
91 // propagate a then-block or else-output to an if-output
92 auto nbo = td.old_to_new_.at(b->outputs()[i]);
93 td.old_to_new_[n->outputs()[i]] = nbo;
94 GRAPH_DEBUG(
95 "Map ",
96 td.old_to_new_[n->outputs()[i]]->debugName(),
97 " to ",
98 nbo->debugName());
99 }
100 };
101
102 // uncomment for debugging
103 // opt_pn->i_(Symbol::attr("propagate"), 1);
104 opt_pn->setCallback(optional_profiler);
105 }
106
107 // loop counter is implicit in the loop body outputs, we need to make
108 // it explicit so it can used in 2+ iterations
traceLoopCounter(Node * n,ProfilingRecord * pr,TracingData & td)109 void traceLoopCounter(Node* n, ProfilingRecord* pr, TracingData& td) {
110 LoopView lv(n);
111 auto opt_pn = pr->createProfileIValueNode(lv.currentTripCount());
112 eraseAllOutputs(opt_pn);
113 lv.bodyBlock()->prependNode(opt_pn);
114 std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
115 std::lock_guard<std::mutex> lock(pr->mutex_);
116 // frame_id is unused
117 int64_t frame_id = 0;
118 pop(stack, frame_id);
119 int64_t loop_counter = 0;
120 pop(stack, loop_counter);
121 WithInsertPoint wip(td.traced_graph_->block());
122 auto lc = td.traced_graph_->insertConstant(loop_counter);
123 LoopView lv(n);
124 td.old_to_new_[lv.currentTripCount()] = lc;
125 };
126
127 // uncomment for debugging
128 // opt_pn->i_(Symbol::attr("loop_counter"), 1);
129 opt_pn->setCallback(optional_profiler);
130 }
131
132 // Similar to how we propagate the mappings for If nodes, we need to propagate
133 // the mappings from the loop body to the beginning of the block in case we
134 // run another iteration and to the outputs of the Loop node, for any logic
135 // downstream that uses the output values of the loop node
traceLoop(Node * n,ProfilingRecord * pr,TracingData & td)136 static void traceLoop(Node* n, ProfilingRecord* pr, TracingData& td) {
137 std::vector<Value*> empty_values{};
138
139 // this is a propagation node for block inputs (phi values)
140 // these come from either `prim::Loop` inputs or loop body outputs
141 {
142 auto opt_pn = pr->createProfileIValueNode(empty_values);
143 eraseAllOutputs(opt_pn);
144 opt_pn->insertBefore(n);
145 LoopView lv(n);
146 std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
147 std::lock_guard<std::mutex> lock(pr->mutex_);
148
149 // frame_id is unused
150 int64_t frame_id = 0;
151 pop(stack, frame_id);
152
153 LoopView lv(n);
154 TORCH_INTERNAL_ASSERT(
155 lv.bodyCarriedInputs().size() == lv.carriedInputs().size());
156 for (size_t i = 0; i < lv.bodyCarriedInputs().size(); i++) {
157 auto bno = td.old_to_new_.at(lv.carriedInputs()[i]);
158 td.old_to_new_[lv.bodyCarriedInputs()[i]] = bno;
159 GRAPH_DEBUG(
160 "Map ",
161 td.old_to_new_[lv.bodyCarriedInputs()[i]]->debugName(),
162 " to ",
163 bno->debugName());
164 }
165 };
166
167 // uncomment for debugging
168 // opt_pn->i_(Symbol::attr("loop_entry"), 1);
169 opt_pn->setCallback(optional_profiler);
170 }
171
172 {
173 insertTracingNodes(LoopView(n).bodyBlock(), pr, td);
174 traceLoopCounter(n, pr, td);
175 }
176
177 // this is a propagation node for loop outputs
178 {
179 auto opt_pn = pr->createProfileIValueNode(empty_values);
180 eraseAllOutputs(opt_pn);
181 LoopView(n).bodyBlock()->appendNode(opt_pn);
182
183 // opt_pn->i_(Symbol::attr("loop_propagate"), 1);
184
185 std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
186 std::lock_guard<std::mutex> lock(pr->mutex_);
187
188 // frame_id is unused
189 int64_t frame_id = 0;
190 pop(stack, frame_id);
191
192 LoopView lv(n);
193
194 TORCH_INTERNAL_ASSERT(
195 lv.bodyCarriedOutputs().size() == lv.carriedOutputs().size());
196 for (size_t i = 0; i < lv.bodyCarriedOutputs().size(); i++) {
197 auto bno = td.old_to_new_.at(lv.bodyCarriedOutputs()[i]);
198 td.old_to_new_[lv.carriedOutputs()[i]] = bno;
199 GRAPH_DEBUG(
200 "Map ",
201 td.old_to_new_[lv.bodyCarriedOutputs()[i]]->debugName(),
202 " to ",
203 bno->debugName());
204 }
205 };
206
207 // uncomment for debugging
208 // opt_pn->i_(Symbol::attr("loop_exit"), 1);
209 opt_pn->setCallback(optional_profiler);
210 }
211 }
212
213 // walks all the nodes in a block and adds profiled nodes to each node
214 // see the comment for `optional_profiler` below
insertTracingNodes(Block * block,ProfilingRecord * pr,TracingData & td)215 void insertTracingNodes(Block* block, ProfilingRecord* pr, TracingData& td) {
216 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
217 auto n = *it;
218 it++;
219
220 GRAPH_DEBUG("Inserting trace for ", getHeader(n));
221 if (n->kind() == prim::If) {
222 IfView ifv(n);
223 createPropNodeForIfBlock(ifv.thenBlock(), n, pr, td);
224 createPropNodeForIfBlock(ifv.elseBlock(), n, pr, td);
225 continue;
226 }
227
228 if (n->kind() == prim::Loop) {
229 traceLoop(n, pr, td);
230 continue;
231 }
232
233 TORCH_INTERNAL_ASSERT(n->blocks().empty());
234 auto opt_pn = pr->createProfileIValueNode(n->outputs());
235 eraseAllOutputs(opt_pn);
236 opt_pn->insertAfter(n);
237
238 // we only use the `opt_pn->node()` to trigger the handler
239 // we still capture the actual scripted node `n` we want to trace
240 // we look at its inputs, map them to the inputs in the traced graph
241 // and create a new node with `traceNode`
242 std::function<void(Stack&)> optional_profiler = [pr, n, &td](Stack& stack) {
243 std::lock_guard<std::mutex> lock(pr->mutex_);
244
245 // frame_id is unused
246 int64_t frame_id = 0;
247 pop(stack, frame_id);
248
249 GRAPH_DEBUG("Tracing ", getHeader(n));
250 auto tracer = traceNode(n, td, stack);
251 auto outputs_size = n->outputs().size();
252 auto iivs = pop(stack, outputs_size);
253 for (size_t j = 0; j < outputs_size; j++) {
254 auto& iiv = iivs[j];
255 if (iiv.isTensor()) {
256 auto t = iiv.toTensor();
257 auto type = t.defined() ? tensorTypeInCurrentExecutionContext(t)
258 : TensorType::get();
259 tracer->outputs().at(j)->setType(type);
260 }
261 }
262 };
263
264 opt_pn->setCallback(optional_profiler);
265 }
266 }
267 } // namespace
268
269 // To trace graph we create a profile node for every one
270 // in a scripted graph. When a profiled node handler runs
271 // we insert a new traced node in a trace graph
272 // If the profiled node handler is called in a loop
273 // we will have multiple nodes.
274 // We also maintain the mapping between the outputs of traced
275 // nodes and the outputs of the node in the scripted graph.
276 // There are a few subtleties with tracing Ifs and Loops
277 // discussed above
TraceGraph(const std::shared_ptr<Graph> & graph,Stack & stack)278 std::shared_ptr<Graph> TraceGraph(
279 const std::shared_ptr<Graph>& graph,
280 Stack& stack) {
281 TracingData td;
282 GRAPH_DUMP("Before Inline:", graph);
283 Inline(*graph);
284 EliminateDeadCode(graph);
285 GRAPH_DUMP("After Inline:", graph);
286 auto pr = ProfilingRecord::instrumentGraph(graph);
287 for (auto inp : pr->profiled_graph_->inputs()) {
288 auto ni = td.traced_graph_->addInput();
289 ni->copyMetadata(inp);
290 ni->setType(ni->type());
291 td.old_to_new_[inp] = ni;
292 }
293
294 // Set type of the graph inputs using the inputs from the stack.
295 // This needs to be done before running the interpreter because the stack
296 // will only have the outputs after the run.
297 for (auto i : c10::irange(stack.size())) {
298 if (stack[i].isTensor()) {
299 td.traced_graph_->inputs().at(i)->setType(
300 tensorTypeInCurrentExecutionContext(stack[i].toTensor()));
301 }
302 }
303
304 ProfilingRecord::removeProfileCounter(pr->profiled_graph_->block());
305 ProfilingRecord::removeProfilingNodes(pr->profiled_graph_->block());
306 insertTracingNodes(pr->profiled_graph_->block(), pr.get(), td);
307 GRAPH_DUMP("Profiling Graph:", pr->profiled_graph_);
308 Code cd(pr->profiled_graph_, "");
309 InterpreterState is{cd};
310 is.run(stack);
311 for (auto out : pr->profiled_graph_->outputs()) {
312 td.traced_graph_->block()->registerOutput(td.old_to_new_.at(out));
313 }
314
315 GRAPH_DUMP("Traced graph:", td.traced_graph_);
316 return td.traced_graph_;
317 }
318 } // namespace torch::jit
319