xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/fusion.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/static/fusion.h>
2 
3 #include <ATen/core/symbol.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/jit_log.h>
6 #include <torch/csrc/jit/passes/canonicalize.h>
7 #include <torch/csrc/jit/passes/constant_pooling.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 #include <torch/csrc/jit/passes/freeze_module.h>
10 #include <torch/csrc/jit/passes/remove_mutation.h>
11 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
12 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
13 #include <torch/csrc/jit/runtime/custom_operator.h>
14 #include <torch/csrc/jit/runtime/graph_iterator.h>
15 #include <torch/csrc/jit/runtime/jit_trace.h>
16 #include <torch/csrc/jit/runtime/static/impl.h>
17 #include <torch/csrc/jit/runtime/static/ops.h>
18 #include <torch/csrc/jit/runtime/static/passes.h>
19 
20 namespace torch::jit {
21 
22 void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size);
23 
fuseStaticSubgraphs(std::shared_ptr<Graph> graph,size_t min_size)24 void fuseStaticSubgraphs(std::shared_ptr<Graph> graph, size_t min_size) {
25   Inline(*graph);
26   ReplaceWithCopy(graph);
27   ReplaceWithMaybeCopy(graph);
28   ConstantPropagation(graph);
29   Canonicalize(graph);
30   ConstantPropagation(graph);
31   RemoveTensorMutation(graph);
32   ConstantPropagation(graph);
33   EliminateDeadCode(graph);
34   auto aliasDb = std::make_unique<AliasDb>(graph);
35   createFusionGroups(graph->block(), aliasDb.get(), min_size);
36   ConstantPooling(graph);
37   ConstantPropagation(graph);
38   torch::jit::EliminateDeadCode(graph);
39 }
40 
createStaticSubgraphRuntime(const Node * node)41 static Operation createStaticSubgraphRuntime(const Node* node) {
42   auto g = node->g(attr::Subgraph);
43   auto module = std::make_shared<torch::jit::StaticModule>(g);
44   auto num_inputs = module->num_inputs();
45   return [module, num_inputs](Stack& stack) {
46     RECORD_FUNCTION("Static Runtime", std::vector<c10::IValue>());
47     auto inps = torch::jit::last(stack, num_inputs);
48     // TODO maybe avoid call to vec
49     auto outputs = (*module)(inps.vec(), {});
50     torch::jit::drop(stack, num_inputs);
51 
52     if (module->num_outputs() > 1) {
53       for (auto& o : outputs.toTupleRef().elements()) {
54         push_one(stack, std::move(o));
55       }
56     } else {
57       push_one(stack, std::move(outputs));
58     }
59     return 0;
60   };
61 }
62 
// Registers the executor for prim::StaticSubgraph nodes with the JIT
// interpreter, using the StaticModule-backed Operation built above.
RegisterOperators StaticSubgraphOps({torch::jit::Operator(
    prim::StaticSubgraph,
    createStaticSubgraphRuntime,
    AliasAnalysisKind::INTERNAL_SPECIAL_CASE)});
67 
// Early-exit helper for the bool-returning predicate functions below:
// returns false from the enclosing function when `cond` does not hold,
// logging the stringified condition at GRAPH_DEBUG level.
#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return false;                           \
  }
73 
canHandle(Node * node)74 static bool canHandle(Node* node) {
75   for (Value* input : node->inputs()) {
76     bool is_tensor = !!input->type()->cast<TensorType>();
77     auto list_type = input->type()->cast<ListType>();
78     bool is_list = list_type && list_type->getElementType()->cast<TupleType>();
79     auto tuple_type = input->type()->cast<TupleType>();
80     bool is_tuple = [&]() -> bool {
81       if (!tuple_type) {
82         return false;
83       }
84       for (auto& t : tuple_type->elements()) {
85         if (!t->cast<TensorType>()) {
86           return false;
87         }
88       }
89       return true;
90     }();
91     if (!(is_tensor || is_list || is_tuple)) {
92       if (input->node()->kind() != prim::Constant) {
93         return false;
94       }
95     }
96   }
97 
98   auto kind = node->kind();
99   if (kind.is_prim()) {
100     REQ(kind == prim::TupleConstruct || kind == prim::ListConstruct ||
101         kind == prim::StaticSubgraph);
102     if (kind == prim::TupleConstruct || kind == prim::ListConstruct) {
103       for (Value* input : node->inputs()) {
104         if (!input->type()->cast<TensorType>()) {
105           return false;
106         }
107       }
108     }
109     return true;
110   }
111 
112   // TODO add "canRunNatively" once memory management is audited
113   return getOutOfPlaceOperation(node) != nullptr;
114 }
115 
canMerge(Node * consumer,Node * producer,AliasDb * aliasDb)116 static bool canMerge(Node* consumer, Node* producer, AliasDb* aliasDb) {
117   // Only fuse within a block
118   REQ(consumer->owningBlock() == producer->owningBlock());
119 
120   // Symbolic checks
121   REQ(canHandle(producer) || producer->kind() == prim::StaticSubgraph);
122   TORCH_INTERNAL_ASSERT(
123       consumer->kind() == prim::StaticSubgraph || canHandle(consumer));
124 
125   // Alias checks
126   REQ(aliasDb->couldMoveBeforeTopologically(producer, consumer));
127 
128   // Ops that return aliases can only be folded if this is the only use.
129   if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze ||
130       producer->kind() == prim::ConstantChunk) {
131     for (auto& use : producer->output(0)->uses()) {
132       REQ(use.user == consumer);
133     }
134   }
135 
136   return true;
137 }
138 
getOrCreateStaticSubgraph(Node * n,AliasDb * aliasDb)139 static Node* getOrCreateStaticSubgraph(Node* n, AliasDb* aliasDb) {
140   if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::StaticSubgraph) {
141     return n;
142   }
143   GRAPH_UPDATE("Creating a static subgraph::Group node from: ", *n);
144   return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
145       n, prim::StaticSubgraph, *aliasDb);
146 }
147 
sortReverseTopological(ArrayRef<Value * > inputs,Block * b)148 static value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* b) {
149   value_list result;
150   for (auto i : inputs) {
151     if (i->node()->owningBlock() == b) {
152       result.push_back(i);
153     }
154   }
155   // Sort in reverse topological order
156   std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
157     return a->node()->isAfter(b->node());
158   });
159   return result;
160 }
161 
debugDumpFusionGroup(const std::string & msg,Node * n)162 static void debugDumpFusionGroup(const std::string& msg, Node* n) {
163   GRAPH_DEBUG(msg, *n);
164   if (n->kind() == prim::StaticSubgraph) {
165     GRAPH_DEBUG(*n->g(attr::Subgraph));
166   }
167 }
168 
tryMerge(Node * fusion_group,Node * to_merge,AliasDb * aliasDb)169 static std::optional<Node*> tryMerge(
170     Node* fusion_group,
171     Node* to_merge,
172     AliasDb* aliasDb) {
173   if (!canMerge(fusion_group, to_merge, aliasDb)) {
174     return std::nullopt;
175   }
176 
177   std::vector<Node*> nodes_to_merge = {to_merge};
178 
179   if (to_merge->kind() == aten::cat) {
180     Node* listconstruct = to_merge->input(0)->node();
181     nodes_to_merge.push_back(listconstruct);
182   }
183 
184   // First, try to move all the nodes we want to fuse next to the fusion
185   // group.
186   Node* move_point = fusion_group;
187   for (auto n : nodes_to_merge) {
188     GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n));
189     if (!aliasDb->moveBeforeTopologicallyValid(n, move_point)) {
190       GRAPH_UPDATE("Failed to move because of AliasDb checks!");
191       return std::nullopt;
192     }
193     move_point = n;
194   }
195 
196   // Now all the nodes that we're going to fuse are moved next to the fusion
197   // group, so we can safely merge them into the fusion group subgraph.
198   fusion_group = getOrCreateStaticSubgraph(fusion_group, aliasDb);
199 
200   for (auto n : nodes_to_merge) {
201     GRAPH_UPDATE("Merging ", getHeader(n));
202     SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
203         n, fusion_group, *aliasDb);
204   }
205   return fusion_group;
206 }
207 
createFusionGroup(Node * fusion_node,AliasDb * aliasDb)208 static std::pair<graph_node_list::iterator, bool> createFusionGroup(
209     Node* fusion_node,
210     AliasDb* aliasDb) {
211   fusion_node = getOrCreateStaticSubgraph(fusion_node, aliasDb);
212 
213   GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n");
214   auto inputs =
215       sortReverseTopological(fusion_node->inputs(), fusion_node->owningBlock());
216   for (auto input : inputs) {
217     debugDumpFusionGroup("Current fusion group: ", fusion_node);
218     GRAPH_DEBUG("Trying to merge: ", *input->node());
219     if (auto maybe_fusion_group =
220             tryMerge(fusion_node, input->node(), aliasDb)) {
221       // we successfully merged, so the new group's `inputs` may have
222       // changed. So rescan the new group for more merging opportunities.
223       return std::make_pair(
224           maybe_fusion_group.value()->reverseIterator(), true);
225     }
226   }
227 
228   return std::make_pair(++fusion_node->reverseIterator(), false);
229 }
230 
scanNode(Node * n,AliasDb * aliasDb)231 static std::pair<graph_node_list::iterator, bool> scanNode(
232     Node* n,
233     AliasDb* aliasDb) {
234   GRAPH_DEBUG("Considering node:", *n);
235 
236   if (!canHandle(n)) {
237     return std::make_pair(++n->reverseIterator(), false);
238   }
239 
240   return createFusionGroup(n, aliasDb);
241 }
242 
inlineIfTooSmall(Node * n,size_t min_size)243 static bool inlineIfTooSmall(Node* n, size_t min_size) {
244   if (n->kind() != prim::StaticSubgraph) {
245     return false;
246   }
247   auto subgraph = SubgraphUtils::getSubgraph(n);
248   size_t num_nodes = std::distance(
249       subgraph->block()->nodes().begin(), subgraph->block()->nodes().end());
250   if (num_nodes < min_size) {
251     GRAPH_UPDATE("Fusion group is too small, unmerging: ", *n);
252     SubgraphUtils::unmergeSubgraph(n);
253     return true;
254   }
255   ConstantPooling(subgraph);
256   ConstantPropagation(subgraph);
257   return false;
258 }
259 
inlineSmallFusionGroups(Block * block,size_t min_size)260 static void inlineSmallFusionGroups(Block* block, size_t min_size) {
261   for (Node* n : block->nodes()) {
262     for (Block* b : n->blocks()) {
263       inlineSmallFusionGroups(b, min_size);
264     }
265     inlineIfTooSmall(n, min_size);
266   }
267 }
268 
createFusionGroups(Block * block,AliasDb * aliasDb,size_t min_size)269 void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size) {
270   bool any_changed = true;
271   while (any_changed) {
272     any_changed = false;
273     for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
274       bool changed = false;
275       std::tie(it, changed) = scanNode(*it, aliasDb);
276       any_changed |= changed;
277     }
278   }
279 
280   for (Node* n : block->nodes()) {
281     for (Block* b : n->blocks()) {
282       createFusionGroups(b, aliasDb, min_size);
283     }
284   }
285 
286   // Try to merge adjacent fusion groups together. Because we have only merged
287   // by looking at graph inputs, without this we would not attempt to merge
288   // adjacent fusion groups that don't have a dependency on each other
289 
290   std::vector<Node*> initial_fusion_groups;
291   for (Node* n : block->nodes()) {
292     if (n->kind() == prim::StaticSubgraph) {
293       initial_fusion_groups.push_back(n);
294     }
295   }
296 
297   Node* prev_fusion_group =
298       !initial_fusion_groups.empty() ? initial_fusion_groups[0] : nullptr;
299 
300   for (const auto i : c10::irange(1, initial_fusion_groups.size())) {
301     // Try merging the just created fusion group into the previous one.
302     // If it did not work, then put the previous fusion group into
303     // fusion_groups vector - we will not touch it anymore in this loop.
304     // If merging succeeded, save the merged group as the "previous" fusion
305     // group so that we can try to merge the next one into it.
306 
307     Node* fusion_group = initial_fusion_groups[i];
308     debugDumpFusionGroup(
309         "Trying to merge into the previous fusion group: ", prev_fusion_group);
310     if (auto merged_fusion_group =
311             tryMerge(prev_fusion_group, fusion_group, aliasDb)) {
312       prev_fusion_group = *merged_fusion_group;
313       debugDumpFusionGroup(
314           "Successfully merged into the previous fusion group: ",
315           prev_fusion_group);
316     } else {
317       GRAPH_DEBUG("Cannot merge into the previous fusion group");
318       prev_fusion_group = fusion_group;
319     }
320   }
321   inlineSmallFusionGroups(block, min_size);
322 }
323 
inlineFallbackGraphs(std::shared_ptr<Graph> graph)324 static void inlineFallbackGraphs(std::shared_ptr<Graph> graph) {
325   DepthFirstGraphNodeIterator it(graph);
326 
327   Node* n = nullptr;
328   while ((n = it.next()) != nullptr) {
329     if (n->kind() == prim::FallbackGraph) {
330       SubgraphUtils::unmergeSubgraph(n);
331     }
332   }
333 }
334 
performTensorExprFusion(std::shared_ptr<Graph> graph,std::vector<IValue> sample_inputs)335 void performTensorExprFusion(
336     std::shared_ptr<Graph> graph,
337     std::vector<IValue> sample_inputs) {
338   // Enable TensorExpr fusion with dynamic shapes
339   setTensorExprDynamicShapeFusionEnabled(true);
340   GRAPH_DEBUG("Graph before tracing: ", *graph);
341   auto traced_graph = TraceGraph(graph, sample_inputs);
342   GRAPH_DEBUG("Graph after tracing: ", *traced_graph);
343   FuseTensorExprs(
344       traced_graph,
345       /*min_group_size*/ 2,
346       /*add_composed_op*/ true,
347       /*fuse_to_dynamic_shapes*/ true);
348   RemoveTensorTypeSpecializations(graph);
349   inlineFallbackGraphs(traced_graph);
350   graph->block()->clear();
351   graph->block()->cloneFrom(traced_graph->block(), nullptr);
352   GRAPH_DUMP("Graph after fusion: ", graph);
353 }
354 
355 } // namespace torch::jit
356