#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>

namespace torch::jit {

namespace {

struct WorkBlock : public std::pair<Node*, Node*> {
  using pair::pair;

  Node* begin() {
    return this->first;
  }
  Node* end() {
    return this->second;
  }
};

class SubgraphSlicer {
 public:
  SubgraphSlicer(
      Block* block,
      std::shared_ptr<Graph> graph,
      size_t minSubgraphSize,
      AliasDb& aliasDb,
      std::vector<Node*>& diff_nodes)
      : block_(block),
        graph_(std::move(graph)),
        minSubgraphSize_(minSubgraphSize),
        aliasDb_(aliasDb),
        diff_nodes_(diff_nodes) {}

  void run() {
    // We maintain alias db correctness in-place while building up the autodiff
    // subgraphs; however, it is difficult to preserve correctness when
    // un-inlining autodiff subgraphs. We therefore first recursively construct
    // all subgraphs, and only then clean up and unmerge the small ones.
    buildupSubgraphs();
    GRAPH_DUMP("before unfuseAliasedOutputs", graph_);
    unfuseAliasedOutputs(block_);
    cleanupSubgraphs();
    // Run CSE globally once to eliminate duplicates that may have occurred
    // while inlining subgraphs.
    EliminateCommonSubexpression(graph_);
  }

  void cleanupSubgraphs() {
    auto curNode = *block_->nodes().rbegin();
    while (curNode != *block_->nodes().rend()) {
      // Save the previous node, since we might delete `curNode` in the block
      // below.
      auto prevNode = curNode->prev();
      if (curNode->kind() == prim::DifferentiableGraph) {
        // Inlining nodes may cause some subexpressions to come back into the
        // subgraphs (for example, copying constants in repeatedly will
        // generate redundant prim::Constants). Run CSE to clean them up.
        EliminateCommonSubexpression(curNode->g(attr::Subgraph));

        if (!inlineIfTooSmall(curNode)) {
          diff_nodes_.push_back(curNode);
        }
      }
      curNode = prevNode;
    }

    for (Node* n : block_->nodes()) {
      for (Block* b : n->blocks()) {
        SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .cleanupSubgraphs();
      }
    }
  }

  void buildupSubgraphs() {
    // We need to run the slicer multiple times in order to get all merge
    // opportunities. This is because moveBeforeTopologicallyValid may reorder
    // nodes to be AFTER the current iteration point. In order to properly
    // consider those nodes for merging, we need to run the pass until no
    // changes have been made.
    //
    // Example:
    //   c = f(a, b)
    //   d = f(c)
    //   e = f(d)  <- iter is here, moving upward
    // After c.moveBeforeTopologicallyValid(e), we have:
    //   c = f(a, b)
    //   e = f(d)  <- iter still here
    //   d = f(c)  <- this is the node that was moved to the other side

    // see [workblocks]
    auto workblocks = buildWorkBlocks();
    for (auto& workblock : workblocks) {
      bool any_changed = true;
      while (any_changed) {
        any_changed = false;
        for (auto it = workblock.end()->reverseIterator();
             it != workblock.begin()->reverseIterator();) {
          auto [tmp_it, changed] = scanNode(*it);
          it = tmp_it;
          any_changed |= changed;
        }
      }
    }

    // Construct subgraphs recursively.
    for (Node* n : block_->nodes()) {
      for (auto subBlock : n->blocks()) {
        SubgraphSlicer(
            subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .buildupSubgraphs();
      }
    }
  }

 private:
  void unfuseAliasedOutputs(Block* b) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      // we walk in the reverse order, so we can skip
      // nodes that might get unfused after the current
      // prim::DifferentiableGraph
      for (auto n : b->nodes().reverse()) {
        if (n->kind() == prim::DifferentiableGraph) {
          // aliased outputs in DifferentiableGraphs must be unfused
          // since autodiff doesn't know how to handle them correctly
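          // (An illustrative sketch, not taken from real IR: if a subgraph
          // returned both %x and %y = aten::t(%x), the two outputs would
          // alias the same storage, and computing their gradients as if
          // they were independent tensors would be incorrect.)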
          // N.B. we use |= rather than || so that the second unmerge call
          // is not short-circuited
          any_changed |= SubgraphUtils::unmergeAliasedOutputs(n);
          any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n);
          GRAPH_DEBUG(
              "any_changed on ",
              any_changed,
              " ",
              n->g(attr::Subgraph)->toString(false));
        }
      }
    }

    for (Node* n : b->nodes()) {
      for (Block* ib : n->blocks()) {
        unfuseAliasedOutputs(ib);
      }
    }
  }

  std::vector<WorkBlock> buildWorkBlocks() {
    // [workblocks]
    // The IR has many nodes which can never be reordered around, such as a
    // prim::Bailout. If a node N is surrounded by two such nodes, A and B,
    // then a differentiable subgraph created from N can only contain nodes
    // from (A, B). The nodes from A to B represent one work block for the
    // subgraph slicer to work on. By creating these up front, we avoid
    // retraversing the whole graph block every time scanNode returns, and we
    // can also avoid attempting to create differentiable subgraphs in work
    // blocks that contain fewer than minSubgraphSize_ differentiable nodes.
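    //
    // An illustrative sketch (hypothetical IR; prim::Print has side effects):
    //   %1 = aten::mul(%a, %b)
    //   %2 = aten::add(%1, %c)
    //        prim::Print(%2)    <- cannot reorder across this node
    //   %3 = aten::tanh(%2)
    // yields two work blocks, {%1, %2} and {%3}; with minSubgraphSize_ == 2,
    // only the first is kept as a merge candidate.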

    Node* end_bound_node = block_->return_node();
    Node* curr = end_bound_node->prev();

    std::vector<WorkBlock> worklist;
    size_t differentiable_nodes = 0;

    while (curr != block_->param_node()) {
      differentiable_nodes += shouldConsiderForMerge(curr);

      // cannot reorder around side effectful nodes
      if (curr->hasSideEffects()) {
        // close the current work block, keeping it only if it contains
        // enough differentiable nodes to form a subgraph
        if (differentiable_nodes >= minSubgraphSize_) {
          worklist.emplace_back(curr, end_bound_node);
        }
        differentiable_nodes = 0;
        end_bound_node = curr;
      }
      curr = curr->prev();
    }

    if (differentiable_nodes >= minSubgraphSize_) {
      worklist.emplace_back(curr, end_bound_node);
    }

    return worklist;
  }

  // Inline this node's group subgraph into the outer graph if it's smaller
  // than the specified minimum size.
  //
  // Returns true if an inlining has occurred, false otherwise.
  bool inlineIfTooSmall(Node* n) {
    AT_ASSERT(n->kind() == prim::DifferentiableGraph);
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t i = 0;
    for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
         ++it) {
      i += !it->notExecutedOp();
      if (i >= minSubgraphSize_) {
        return false;
      }
    }

    SubgraphUtils::unmergeSubgraph(n);
    return true;
  }

  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  bool isViewOp(Node* n) {
    switch (n->kind()) {
      case aten::view:
      case aten::view_as:
      case aten::reshape:
      case aten::reshape_as:
      case aten::transpose:
      case aten::expand:
      case aten::expand_as:
        return true;
    }
    return false;
  }

  bool shouldConsiderForMerge(Node* node) {
    // if we're already in the process of merging
    if (node->kind() == prim::DifferentiableGraph) {
      return true;
    }
    if (node->kind() == prim::Constant) {
      return false;
    }

    // view ops as outputs of differentiable subgraphs can cause incorrect
    // differentiation; for now, do not include them in the subgraph
    if (isViewOp(node)) {
      return false;
    }

    return isDifferentiable(node);
  }

  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (shouldConsiderForMerge(consumer)) {
      if (consumer->kind() != prim::DifferentiableGraph) {
        consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
            consumer, prim::DifferentiableGraph, aliasDb_);
      }
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto input : inputs) {
        if (auto group = tryMerge(consumer, input->node())) {
          // we successfully merged, so the new group's `inputs` may have
          // changed. So rescan the new group for more merging opportunities.
          return std::make_pair(group.value()->reverseIterator(), true);
        }
      }
    }

    return std::make_pair(++consumer->reverseIterator(), false);
  }

  // Try to merge `producer` into `consumer`. If successful, this destroys
  // `producer` and returns the `consumer` group.
  std::optional<Node*> tryMerge(Node* consumer, Node* producer) {
    AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
    bool canMerge = shouldConsiderForMerge(producer) &&
        aliasDb_.moveBeforeTopologicallyValid(producer, consumer);

    if (!canMerge) {
      return std::nullopt;
    }

    SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
        producer, consumer, aliasDb_);
    return consumer;
  }

  Block* block_;
  std::shared_ptr<Graph> graph_;
  size_t minSubgraphSize_;
  AliasDb& aliasDb_;
  std::vector<Node*>& diff_nodes_;
};

std::optional<bool> getProfileNodeRequiresGrad(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == prim::profile);
  if (!n->hasAttribute(attr::profiled_type)) {
    return std::nullopt;
  }
  auto& type = n->ty(attr::profiled_type);
  if (type->castRaw<TensorType>() == nullptr) {
    return std::nullopt;
  }
  return type->expectRef<TensorType>().requiresGrad();
}

struct ContextMapping {
  std::vector<const Node*> ctx_stack_;
  std::unordered_map<const Node*, const Node*> node_to_ctx_;

  void processNode(Node* n) {
    node_to_ctx_[n] = ctx_stack_.back();

    if (n->kind() == prim::Enter) {
      ctx_stack_.push_back(n);
    } else if (n->kind() == prim::Exit) {
      ctx_stack_.pop_back();
    }
  }

  void processBlock(Block* block) {
    for (Node* n : block->nodes()) {
      processNode(n);
      for (Block* b : n->blocks()) {
        processBlock(b);
      }
      if (n->kind() == prim::DifferentiableGraph) {
        const auto& subgraph = n->g(attr::Subgraph);
        processBlock(subgraph->block());
      }
    }
  }

  ContextMapping(const std::shared_ptr<Graph>& graph) {
    ctx_stack_.push_back(nullptr);
    processBlock(graph->block());
  }

  const Node* get(const Node* n) const {
    auto it = node_to_ctx_.find(n);
    TORCH_INTERNAL_ASSERT(
        it != node_to_ctx_.end(),
        "Cannot find node in node-to-context mapping.");
    return it->second;
  }

  bool has(const Node* n) const {
    return node_to_ctx_.find(n) != node_to_ctx_.end();
  }
};

std::optional<bool> findRequiresGradForOutput(
    Node* diff_graph,
    Value* output,
    const ContextMapping& ctx_mapping) {
  for (auto& use : output->uses()) {
    // [Only consider profiles in the same context]
    // Ignore profiled uses if the use is within a different context.
    // For example, a profile node within a no_grad() context will record the
    // wrong requires_grad information.
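    // (An illustrative sketch: in
    //     with torch.no_grad():
    //         y = x * 2
    //  the prim::profile node observing y sits between the no_grad
    //  prim::Enter / prim::Exit pair and records requires_grad=False,
    //  regardless of whether x requires grad outside that context.)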
    if (ctx_mapping.has(use.user) &&
        ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) {
      continue;
    }

    if (use.user->kind() == prim::profile) {
      auto req_grad_use = getProfileNodeRequiresGrad(use.user);
      if (req_grad_use.has_value()) {
        return req_grad_use;
      }
    }

    // maybe the profile node got absorbed into a differentiable graph
    if (use.user->kind() == prim::DifferentiableGraph) {
      const auto& dg = use.user->g(attr::Subgraph);
      // check all the uses of this graph input to look for profile nodes.
      Value* dg_value = dg->inputs()[use.offset];
      for (auto& dg_use : dg_value->uses()) {
        // See [Only consider profiles in the same context]
        if (ctx_mapping.has(dg_use.user) &&
            ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) {
          continue;
        }

        if (dg_use.user->kind() == prim::profile) {
          auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user);
          if (req_grad_use.has_value()) {
            return req_grad_use;
          }
        }
      }
    }
  }

  return std::nullopt;
}

void AddRequiresGradToDifferentiableGraph(
    Node* diff_graph,
    const ContextMapping& ctx_mapping) {
  TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph);
  const auto& subgraph = diff_graph->g(attr::Subgraph);
  for (auto i : c10::irange(subgraph->outputs().size())) {
    Value* output = subgraph->outputs()[i];
    if (output->node()->kind() == prim::profile) {
      // already have requires_grad info from this profile node
      continue;
    }
    if (output->type()->castRaw<TensorType>() == nullptr) {
      // non-tensors don't get profiled.
      continue;
    }
    if (output->type()->expectRef<TensorType>().requiresGrad().has_value()) {
      continue;
    }

    // this node doesn't have any requires_grad info.
    // look at its uses to try to find a profile node.
    auto requires_grad = findRequiresGradForOutput(
        diff_graph, diff_graph->output(i), ctx_mapping);

    output->setType(output->type()->expectRef<TensorType>().withRequiresGrad(
        requires_grad));
  }
}

void AddRequiresGradOnOutputNodes(
    Block* block,
    const ContextMapping& ctx_mapping) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::DifferentiableGraph) {
      AddRequiresGradToDifferentiableGraph(n, ctx_mapping);
    }
    for (Block* b : n->blocks()) {
      AddRequiresGradOnOutputNodes(b, ctx_mapping);
    }
  }
}

// autodiff.cpp needs to know, for each output, whether or not it requires
// grad. Sometimes a profile node will be present on the output, and sometimes
// it won't. This can happen if there's a node with side effects in between
// the definition of the output node and the profile node; in that case the
// profile node and output node would land in different work blocks and
// couldn't be merged into the same DifferentiableGraph (see [workblocks]).
// It can also happen if the output is profiled twice and the profile nodes
// get removed by unfuseAliasedOutputs.
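// (An illustrative sketch of the first case: a prim::Print sitting between an
// output's defining node and its prim::profile splits the two into separate
// work blocks, so the profile never lands inside the DifferentiableGraph.)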
void AddRequiresGradOnOutputNodes(const std::shared_ptr<Graph>& graph) {
  ContextMapping ctx_mapping(graph);
  AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping);
}
} // anonymous namespace

std::vector<Node*> CreateAutodiffSubgraphs(
    const std::shared_ptr<Graph>& graph,
    size_t threshold) {
  std::vector<Node*> diff_nodes;
  AliasDb db(graph);
  GRAPH_DEBUG("Before creating autodiff subgraphs", *graph);
  SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run();
  GRAPH_DEBUG("After creating autodiff subgraphs", *graph);
  AddRequiresGradOnOutputNodes(graph);
  GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
  return diff_nodes;
}
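
// A minimal usage sketch (hypothetical caller, not part of this file; assumes
// `graph` already contains profiled, differentiable IR):
//
//   std::shared_ptr<Graph> graph = ...;  // e.g. from a profiled run
//   std::vector<Node*> diff_nodes =
//       CreateAutodiffSubgraphs(graph, /*threshold=*/2);
//   // Each returned node is a prim::DifferentiableGraph whose subgraph can
//   // be differentiated via autodiff (torch/csrc/jit/runtime/autodiff.h).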
} // namespace torch::jit