#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>

namespace torch::jit {

namespace {

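// A WorkBlock is a [begin, end) range of nodes within a block that the
// subgraph slicer operates on independently; see note [workblocks] below.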
struct WorkBlock : public std::pair<Node*, Node*> {
  using pair::pair;

  Node* begin() {
    return this->first;
  }
  Node* end() {
    return this->second;
  }
};

class SubgraphSlicer {
 public:
  SubgraphSlicer(
      Block* block,
      std::shared_ptr<Graph> graph,
      size_t minSubgraphSize,
      AliasDb& aliasDb,
      std::vector<Node*>& diff_nodes)
      : block_(block),
        graph_(std::move(graph)),
        minSubgraphSize_(minSubgraphSize),
        aliasDb_(aliasDb),
        diff_nodes_(diff_nodes) {}

  void run() {
    // We maintain alias db correctness in-place while building up the autodiff
    // subgraphs, however it is difficult to preserve correctness when
    // un-inlining autodiff subgraphs. We first recursively construct all
    // subgraphs and then recursively clean up & unmerge the small subgraphs.
    buildupSubgraphs();
    GRAPH_DUMP("before unfuseAliasedOutputs", graph_);
    unfuseAliasedOutputs(block_);
    cleanupSubgraphs();
    // Run CSE globally once to eliminate duplicates that may have occurred
    // while inlining subgraphs.
    EliminateCommonSubexpression(graph_);
  }

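  // Inline back any DifferentiableGraph that ended up too small to be worth
  // keeping (see inlineIfTooSmall) and record the surviving ones in
  // diff_nodes_.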
  void cleanupSubgraphs() {
    auto curNode = *block_->nodes().rbegin();
    while (curNode != *block_->nodes().rend()) {
      // Save the previous node, since we might delete `curNode` in next block
      auto prevNode = curNode->prev();
      if (curNode->kind() == prim::DifferentiableGraph) {
        // Inlining nodes may cause some subexpressions to come back in the
        // subgraphs (for example, repeatedly copying constants in will
        // generate redundant prim::Constants). Run CSE to clean them up.
        EliminateCommonSubexpression(curNode->g(attr::Subgraph));

        if (!inlineIfTooSmall(curNode)) {
          diff_nodes_.push_back(curNode);
        }
      }
      curNode = prevNode;
    }

    for (Node* n : block_->nodes()) {
      for (Block* b : n->blocks()) {
        SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .cleanupSubgraphs();
      }
    }
  }

  void buildupSubgraphs() {
    // We need to run the slicer multiple times in order to get all merge
    // opportunities. This is because moveBeforeTopologicallyValid may reorder
    // nodes to be AFTER the current iteration point. In order to properly
    // consider those nodes for merging, we need to run the pass until no
    // changes have been made.
    //
    // Example:
    // c = f(a, b)
    // d = f(c)
    // e = f(d) <- iter is here, moving upward
    // After c.moveBeforeTopologicallyValid(e), we have:
    // c = f(a, b)
    // e = f(d) <- iter still here
    // d = f(c) <- this node was moved to the other side of the iterator

    // see [workblocks]
    auto workblocks = buildWorkBlocks();
    for (auto& workblock : workblocks) {
      bool any_changed = true;
      while (any_changed) {
        any_changed = false;
        for (auto it = workblock.end()->reverseIterator();
             it != workblock.begin()->reverseIterator();) {
          auto [tmp_it, changed] = scanNode(*it);
          it = tmp_it;
          any_changed |= changed;
        }
      }
    }

    // Construct subgraphs recursively in nested blocks
    for (Node* n : block_->nodes()) {
      for (auto subBlock : n->blocks()) {
        SubgraphSlicer(
            subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .buildupSubgraphs();
      }
    }
  }

 private:
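  // Recursively unmerge any DifferentiableGraph outputs that alias each other
  // or alias the subgraph's inputs; autodiff cannot handle such outputs
  // correctly (see run()).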
  void unfuseAliasedOutputs(Block* b) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      // we walk in the reverse order, so we can skip
      // nodes that might get unfused after the current
      // prim::DifferentiableGraph
      for (auto n : b->nodes().reverse()) {
        if (n->kind() == prim::DifferentiableGraph) {
          // aliased outputs in DifferentiableGraphs must be unfused
          // since autodiff doesn't know how to handle them correctly
          // N.B. we use |= since we don't want `unfuseAliasedOutputs`
          // to short-circuit
          any_changed |= SubgraphUtils::unmergeAliasedOutputs(n);
          any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n);
          GRAPH_DEBUG(
              "any_changed on ",
              any_changed,
              " ",
              n->g(attr::Subgraph)->toString(false));
        }
      }
    }

    for (Node* n : b->nodes()) {
      for (Block* ib : n->blocks()) {
        unfuseAliasedOutputs(ib);
      }
    }
  }

  std::vector<WorkBlock> buildWorkBlocks() {
    // [workblocks]
    // the IR has many nodes which can never be reordered around, such as a
    // prim::Bailout. if a node N is surrounded by two nodes which cannot be
    // reordered, A and B, then a differentiable subgraph that is created from
    // N can only contain nodes from (A, B). The nodes from A to B represent
    // one work block for the subgraph slicer to work on. By creating these up
    // front, we avoid retraversing the whole graph block any time scanNode
    // returns, and we can also avoid attempting to create differentiable
    // subgraphs in work blocks that do not contain at least minSubgraphSize_
    // differentiable nodes.

    Node* end_bound_node = block_->return_node();
    Node* curr = end_bound_node->prev();

    std::vector<WorkBlock> worklist;
    size_t differentiable_nodes = 0;

    while (curr != block_->param_node()) {
      differentiable_nodes += shouldConsiderForMerge(curr);

      // cannot reorder around side effectful nodes
      if (curr->hasSideEffects()) {
        // only keep this work block if it contains enough differentiable
        // nodes to possibly form a subgraph
        if (differentiable_nodes >= minSubgraphSize_) {
          worklist.emplace_back(curr, end_bound_node);
        }
        differentiable_nodes = 0;
        end_bound_node = curr;
      }
      curr = curr->prev();
    }

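    // whatever remains between the param node and the last boundary forms the
    // final work block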
    if (differentiable_nodes >= minSubgraphSize_) {
      worklist.emplace_back(curr, end_bound_node);
    }

    return worklist;
  }

  // Inline this node's group subgraph into the outer graph if it's smaller
  // than the specified minimum size.
  //
  // Returns true if an inlining has occurred, false otherwise.
  bool inlineIfTooSmall(Node* n) {
    AT_ASSERT(n->kind() == prim::DifferentiableGraph);
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t i = 0;
    for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
         ++it) {
      i += !it->notExecutedOp();
      if (i >= minSubgraphSize_) {
        return false;
      }
    }

    SubgraphUtils::unmergeSubgraph(n);
    return true;
  }

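  // Filter `inputs` down to values produced in this block and return them
  // sorted so that the most recently defined value comes first (reverse
  // topological order).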
  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

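  // View-like ops (view/reshape/transpose/expand variants); see
  // shouldConsiderForMerge for why these are kept out of subgraphs.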
  bool isViewOp(Node* n) {
    switch (n->kind()) {
      case aten::view:
      case aten::view_as:
      case aten::reshape:
      case aten::reshape_as:
      case aten::transpose:
      case aten::expand:
      case aten::expand_as:
        return true;
    }
    return false;
  }

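  // Whether `node` is eligible to be placed into (or merged with) a
  // differentiable subgraph.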
  bool shouldConsiderForMerge(Node* node) {
    // if we're already in the process of merging
    if (node->kind() == prim::DifferentiableGraph) {
      return true;
    }
    if (node->kind() == prim::Constant) {
      return false;
    }

    // view ops as outputs of differentiable subgraphs can cause incorrect
    // differentiation. For now, do not include them in the subgraph.
    if (isViewOp(node)) {
      return false;
    }

    return isDifferentiable(node);
  }

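  // Try to merge the producers of `consumer` into it, wrapping `consumer` in a
  // DifferentiableGraph first if necessary. Returns the reverse iterator to
  // resume scanning from and whether any merge happened.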
  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (shouldConsiderForMerge(consumer)) {
      if (consumer->kind() != prim::DifferentiableGraph) {
        consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
            consumer, prim::DifferentiableGraph, aliasDb_);
      }
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto input : inputs) {
        if (auto group = tryMerge(consumer, input->node())) {
          // we successfully merged, so the new group's `inputs` may have
          // changed. So rescan the new group for more merging opportunities.
          return std::make_pair(group.value()->reverseIterator(), true);
        }
      }
    }

    return std::make_pair(++consumer->reverseIterator(), false);
  }


  // Try to merge `producer` into `consumer`. If successful, this destroys
  // `producer` and returns the `consumer` group.
  std::optional<Node*> tryMerge(Node* consumer, Node* producer) {
    AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
    bool canMerge = shouldConsiderForMerge(producer) &&
        aliasDb_.moveBeforeTopologicallyValid(producer, consumer);

    if (!canMerge) {
      return std::nullopt;
    }

    SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
        producer, consumer, aliasDb_);
    return consumer;
  }

  Block* block_;
  std::shared_ptr<Graph> graph_;
  size_t minSubgraphSize_;
  AliasDb& aliasDb_;
  std::vector<Node*>& diff_nodes_;
};

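// Returns the requires_grad bit recorded by a prim::profile node, if it
// profiled a Tensor type; std::nullopt otherwise.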
std::optional<bool> getProfileNodeRequiresGrad(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == prim::profile);
  if (!n->hasAttribute(attr::profiled_type)) {
    return std::nullopt;
  }
  auto& type = n->ty(attr::profiled_type);
  if (type->castRaw<TensorType>() == nullptr) {
    return std::nullopt;
  }
  return type->expectRef<TensorType>().requiresGrad();
}

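// Maps each node to the innermost prim::Enter context it executes under
// (nullptr at the top level). This lets us ignore profile information that
// was recorded in a different grad-mode context, e.g. under no_grad(); see
// note [Only consider profiles in the same context].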
struct ContextMapping {
  std::vector<const Node*> ctx_stack_;
  std::unordered_map<const Node*, const Node*> node_to_ctx_;

  void processNode(Node* n) {
    node_to_ctx_[n] = ctx_stack_.back();

    if (n->kind() == prim::Enter) {
      ctx_stack_.push_back(n);
    } else if (n->kind() == prim::Exit) {
      ctx_stack_.pop_back();
    }
  }

  void processBlock(Block* block) {
    for (Node* n : block->nodes()) {
      processNode(n);
      for (Block* b : n->blocks()) {
        processBlock(b);
      }
      if (n->kind() == prim::DifferentiableGraph) {
        const auto& subgraph = n->g(attr::Subgraph);
        processBlock(subgraph->block());
      }
    }
  }

  ContextMapping(const std::shared_ptr<Graph>& graph) {
    ctx_stack_.push_back(nullptr);
    processBlock(graph->block());
  }

  const Node* get(const Node* n) const {
    auto it = node_to_ctx_.find(n);
    TORCH_INTERNAL_ASSERT(
        it != node_to_ctx_.end(),
        "Cannot find node in node-to-context mapping.");
    return it->second;
  }

  bool has(const Node* n) const {
    return node_to_ctx_.find(n) != node_to_ctx_.end();
  }
};

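// Look through the uses of `output` (including uses that were absorbed into
// other DifferentiableGraphs) for a profile node recorded in the same context
// as `diff_graph`, and return its requires_grad information if any is found.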
std::optional<bool> findRequiresGradForOutput(
    Node* diff_graph,
    Value* output,
    const ContextMapping& ctx_mapping) {
  for (auto& use : output->uses()) {
    // [Only consider profiles in the same context]
    // Ignore profiled uses if the use is within a different context.
    // For example, a profile node within a no_grad() context will record the
    // wrong requires_grad information.
    if (ctx_mapping.has(use.user) &&
        ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) {
      continue;
    }

    if (use.user->kind() == prim::profile) {
      auto req_grad_use = getProfileNodeRequiresGrad(use.user);
      if (req_grad_use.has_value()) {
        return req_grad_use;
      }
    }

    // maybe the profile node got absorbed into a differentiable graph
    if (use.user->kind() == prim::DifferentiableGraph) {
      const auto& dg = use.user->g(attr::Subgraph);
      // check all the uses of this graph input to look for profile nodes.
      Value* dg_value = dg->inputs()[use.offset];
      for (auto& dg_use : dg_value->uses()) {
        // See [Only consider profiles in the same context]
        if (ctx_mapping.has(dg_use.user) &&
            ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) {
          continue;
        }

        if (dg_use.user->kind() == prim::profile) {
          auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user);
          if (req_grad_use.has_value()) {
            return req_grad_use;
          }
        }
      }
    }
  }

  return std::nullopt;
}

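// For each Tensor output of `diff_graph` that carries no requires_grad
// information, try to recover it from profile nodes on the corresponding
// output's uses in the outer graph.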
void AddRequiresGradToDifferentiableGraph(
    Node* diff_graph,
    const ContextMapping& ctx_mapping) {
  TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph);
  const auto& subgraph = diff_graph->g(attr::Subgraph);
  for (auto i : c10::irange(subgraph->outputs().size())) {
    Value* output = subgraph->outputs()[i];
    if (output->node()->kind() == prim::profile) {
      // already have requires_grad info from this profile node
      continue;
    }
    if (output->type()->castRaw<TensorType>() == nullptr) {
      // non-tensors don't get profiled.
      continue;
    }
    if (output->type()->expectRef<TensorType>().requiresGrad().has_value()) {
      continue;
    }

    // this node doesn't have any requires_grad info.
    // look at its uses to try to find a profile node.
    auto requires_grad = findRequiresGradForOutput(
        diff_graph, diff_graph->output(i), ctx_mapping);

    output->setType(output->type()->expectRef<TensorType>().withRequiresGrad(
        requires_grad));
  }
}

void AddRequiresGradOnOutputNodes(
    Block* block,
    const ContextMapping& ctx_mapping) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::DifferentiableGraph) {
      AddRequiresGradToDifferentiableGraph(n, ctx_mapping);
    }
    for (Block* b : n->blocks()) {
      AddRequiresGradOnOutputNodes(b, ctx_mapping);
    }
  }
}

// autodiff.cpp needs to know, for each output, whether or not it requires
// grad. Sometimes a profile node will be present on the output, but sometimes
// it won't be present. This might happen if there's a node with side effects
// in between the definition of the output node and the profile node; in this
// case the profile node and output node would be in different work blocks and
// couldn't be merged into the same DifferentiableGraph (see [workblocks]).
// Or it could happen if the output is profiled twice and the profile nodes get
// removed by unfuseAliasedOutputs.
void AddRequiresGradOnOutputNodes(const std::shared_ptr<Graph>& graph) {
  ContextMapping ctx_mapping(graph);
  AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping);
}
} // anonymous namespace

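// Groups differentiable nodes in `graph` into prim::DifferentiableGraph nodes
// containing at least `threshold` non-trivial operations, and returns the
// created nodes. As an illustrative sketch only (the exact IR depends on the
// input graph and on profiling information), a sequence like
//   %c = aten::mul(%a, %b)
//   %d = aten::add(%c, %b)
// may end up as a single
//   %d = prim::DifferentiableGraph_0(%a, %b)
// whose Subgraph attribute holds the mul and add.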
std::vector<Node*> CreateAutodiffSubgraphs(
    const std::shared_ptr<Graph>& graph,
    size_t threshold) {
  std::vector<Node*> diff_nodes;
  AliasDb db(graph);
  GRAPH_DEBUG("Before creating autodiff subgraphs", *graph);
  SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run();
  GRAPH_DEBUG("After creating autodiff subgraphs", *graph);
  AddRequiresGradOnOutputNodes(graph);
  GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
  return diff_nodes;
}
} // namespace torch::jit