xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/concat_opt.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/concat_opt.h>
2 
3 #include <algorithm>
4 #include <deque>
5 #include <unordered_map>
6 #include <unordered_set>
7 #include <vector>
8 
9 #include <c10/util/ssize.h>
10 #include <torch/csrc/jit/ir/alias_analysis.h>
11 #include <torch/csrc/jit/ir/ir.h>
12 #include <torch/csrc/jit/ir/named_value.h>
13 #include <torch/csrc/jit/jit_log.h>
14 #include <torch/csrc/jit/passes/constant_pooling.h>
15 #include <torch/csrc/jit/passes/dead_code_elimination.h>
16 #include <torch/csrc/jit/passes/remove_mutation.h>
17 #include <torch/csrc/jit/runtime/graph_iterator.h>
18 
19 namespace torch::jit {
20 
21 namespace {
22 
removeCatNodeFromGraph(Node * n)23 void removeCatNodeFromGraph(Node* n) {
24   TORCH_INTERNAL_ASSERT(n->kind() == aten::cat);
25   auto inp_list = n->input(0);
26   GRAPH_UPDATE("Deleting\n", *n);
27   n->destroy();
28   if (!inp_list->hasUses()) {
29     GRAPH_UPDATE("Deleting\n", *inp_list->node());
30     inp_list->node()->destroy();
31   }
32 }
33 
equal(at::ArrayRef<Value * > list1,at::ArrayRef<Value * > list2)34 bool equal(at::ArrayRef<Value*> list1, at::ArrayRef<Value*> list2) {
35   return list1.size() == list2.size() &&
36       std::equal(list1.begin(), list1.end(), list2.begin());
37 }
38 
// Rewrites prim::VarConcat nodes so that a later concat whose first (or
// last) N-1 tensor inputs match an earlier, dominating VarConcat reuses
// that earlier output instead of re-concatenating the same tensors.
//
// Replacements are only *recorded* during the traversal (handleCat) and
// applied afterwards in postprocess(), so the node list being iterated is
// never mutated mid-walk.
class ConcatCommonInputsEliminator {
 public:
  explicit ConcatCommonInputsEliminator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Runs the pass; returns true if the graph was changed.
  bool run() {
    handleBlock(graph_->block());
    return postprocess();
  }

 private:
  // Recursively scan `block` (and any nested sub-blocks) for
  // prim::VarConcat nodes to consider.
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == prim::VarConcat) {
        handleCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Inspect one VarConcat node. If its leading (or trailing) N-1 tensor
  // inputs plus dim match a previously-seen, dominating VarConcat, record
  // a replacement node in concats_to_replace_ for postprocess() to apply.
  void handleCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for CSE opt: ", node);

    // VarConcat inputs are laid out as [tensor..., dim]; split them apart.
    auto curr_all_inputs = node->inputs();
    auto curr_tensor_inputs =
        curr_all_inputs.slice(0, curr_all_inputs.size() - 1);
    auto curr_dim = curr_all_inputs.back();

    // Save the input list and the current cat node, so that this can be
    // used for subsequent cat nodes, unless there are writes to this cat
    // node. When there are writes to this cat node, its output does not
    // represent this concatenated list beyond the writes. Currently, we do
    // not perform such fine-grained analysis. So, if there are any writes to
    // the output, we do not use this cat node for optimization here.
    if (!getOrCreateAliasDb()->hasWriters(node->output())) {
      concated_outputs_.insert(node);
    }

    if (curr_tensor_inputs.size() <= 2) {
      // The case when concat has 2 input tensors could only be optimized if
      // there is another concat of the exact same 2 input tensors. That case
      // is expected to be handled by the CSE pass.
      return;
    }

    // Now, we check if the first N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14

    auto curr_tensor_inputs_prefix =
        curr_tensor_inputs.slice(0, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_prefix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated output if it does not
          // dominate the current concat node.
          continue;
        }

        // New node: concat of the previous output plus the one extra input.
        std::vector<Value*> new_inputs = {
            prev->output(), curr_tensor_inputs.back(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Now, we check if the last N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %10 = prim::ListConstruct(%1, %2)
    //    %11 = aten::cat(%10, ...)
    //    ...
    //    %12 = prim::ListConstruct(%0, %1, %2)  // last 2 inputs same as %11
    //    %13 = aten::cat(%12, ...)
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %10 = prim::ListConstruct(%0, %1)
    //    %11 = aten::cat(%10, ...)
    //    ...
    //    %12 = prim::ListConstruct(%0, %11) // Replace last 2 inputs with %11
    //    %13 = aten::cat(%12, ...)
    //    ...
    //        = %13 ... // Use %13
    auto curr_tensor_inputs_suffix =
        curr_tensor_inputs.slice(1, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_suffix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated list if it does not
          // dominate the current list.
          continue;
        }

        // New node: concat of the one extra input plus the previous output.
        std::vector<Value*> new_inputs = {
            curr_tensor_inputs.front(), prev->output(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Do we need to handle other cases where N-2 or lesser elements from
    // %inputs appear in any of the previous cat ops?
    // TODO.
  }

  // Apply all recorded replacements: insert each new VarConcat before the
  // node it replaces, redirect uses, and destroy the old node. Returns
  // true if anything was replaced.
  bool postprocess() {
    // Replace the list nodes that have been marked.
    bool changed = false;
    for (auto it : concats_to_replace_) {
      auto curr_node = it.first;
      auto new_node = it.second;
      GRAPH_UPDATE("Inserting\n", *new_node, "before\n", *curr_node);
      new_node->insertBefore(curr_node);
      GRAPH_UPDATE("Replacing uses of\n", *curr_node, "with\n", *new_node);
      curr_node->output()->replaceAllUsesWith(new_node->output());
      GRAPH_UPDATE("Deleting\n", *curr_node);
      curr_node->destroy();
      changed = true;
    }
    return changed;
  }

  // Build the alias database lazily, on first use (it is only needed when
  // a VarConcat node is actually encountered).
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  // VarConcat nodes seen so far whose outputs have no writers; candidates
  // for reuse by later VarConcat nodes.
  std::unordered_set<Node*> concated_outputs_;
  // Old VarConcat node -> its (not-yet-inserted) replacement node.
  std::unordered_map<Node*, Node*> concats_to_replace_;
};
209 
210 } // namespace
211 
EliminateConcatCommonInputs(const std::shared_ptr<Graph> & graph)212 bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph) {
213   GRAPH_DUMP("Before eliminating Concat common inputs", graph);
214   bool changed = ConcatCommonInputsEliminator(graph).run();
215   if (changed) {
216     GRAPH_DUMP("After eliminating Concat common inputs", graph);
217   }
218   return changed;
219 }
220 
221 namespace {
222 
// Expands aten::cat nodes with statically-known shapes into an explicit
// output buffer (aten::empty) plus per-input aten::slice / aten::copy_
// pairs, then reuses output buffers across chained cats where possible.
class ConcatExpander {
 public:
  explicit ConcatExpander(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Runs the full expansion pipeline over the graph.
  void run() {
    handleBlock(graph_->block());
    cleanupExpandedCatOps();
    GRAPH_DUMP("Before reusing copy buffers: ", graph_);
    reuseBuffersInCopies();
  }

 private:
  // Recursively scan `block` (and nested sub-blocks) for aten::cat nodes.
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == aten::cat) {
        expandCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Expand cat node into multiple copy nodes.
  //
  // Example:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %10 = prim::ListConstruct(%2, %3)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //         = %11 ... // Use %11
  //
  // After expanding cat:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %20 = aten::empty(...)          // cat output buffer
  //     %21 = aten::slice(%20, ...)     // slice for %2
  //     %22 = aten::copy_(%21, %2)      // copy %2
  //     %23 = aten::slice(%20, ...)     // slice for %3
  //     %24 = aten::copy_(%23, %3)      // copy %3
  //     ...
  //         = %20 ... // Use %20 in place of %11
  void expandCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for expansion: ", node);
    // Do not optimize cat nodes whose inputs are mutated in the graph.
    // TODO: Improve this by checking if it is mutated in the graph region
    // where this optimization is applied.
    if (getOrCreateAliasDb()->hasWriters(node->input(0))) {
      return;
    }
    if (node->input(0)->node()->kind() != prim::ListConstruct) {
      // Unknown form of input to `cat` op.
      return;
    }
    if (!allShapesAreKnown(node)) {
      // Can't expand when shapes are not known for the `cat` op.
      return;
    }
    for (auto cat_inp : node->input(0)->node()->inputs()) {
      if (!shapeIsKnown(cat_inp)) {
        // Can't expand when shapes of the inputs to `cat` are not known.
        return;
      }
    }
    // TODO: Handle non-contiguous Tensors.
    // For example, how to handle the cases where the inputs are all channels
    // last?

    auto maybe_cat_dim = constant_as<int64_t>(node->input(1));
    if (!maybe_cat_dim) {
      // Can't expand when cat dimension is not a constant.
      return;
    }
    auto cat_dim_value = maybe_cat_dim.value();
    auto cat_dim = node->input(1);

    // Set the insertion point to the current `cat` node.
    WithInsertPoint guard(node);
    auto none = graph_->insertConstant(IValue());
    auto one = graph_->insertConstant(1);

    // Insert the constants needed for the `cat` output buffer size.
    auto tensortype = node->output()->type()->expect<TensorType>();
    TORCH_INTERNAL_ASSERT(tensortype);
    auto tensortype_sizes = tensortype->sizes();
    std::vector<Value*> cat_out_size;
    for (size_t i = 0; i < tensortype_sizes.size(); ++i) {
      cat_out_size.push_back(graph_->insertConstant(tensortype_sizes[i]));
    }

    // Create a list of int for `cat` output buffer size.
    auto cat_out_size_list = graph_->createList(IntType::get(), cat_out_size);
    cat_out_size_list->insertBefore(node);

    // Create an empty buffer to be used as `cat` output buffer.
    // TODO: Handle tensors with different dtype, layout, device, memory
    // format, etc.
    auto cat_out_empty = graph_->create(
        aten::empty,
        {cat_out_size_list->output(), none, none, none, none, none});
    cat_out_empty->insertBefore(node);

    // For every input to this `cat` node:
    //   * Create a slice of `cat` output buffer.
    auto cat_out_value = cat_out_empty->output();
    auto cat_inp_list = node->input(0)->node();
    int64_t start_idx = 0;
    auto start = graph_->insertConstant(start_idx);
    for (auto cat_inp : cat_inp_list->inputs()) {
      // Create a slice of the cat output buffer that correspond to
      // this input size and position in the output.
      auto cat_inp_tensor_type =
          dynamic_cast<TensorType*>(cat_inp->type().get());
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type);
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type->dim());
      auto cat_inp_tensortype_sizes = cat_inp_tensor_type->sizes();
      // End of this input's slice along the cat dim; sizes are known to be
      // complete here because shapeIsKnown() was checked above.
      auto end_idx = start_idx + *cat_inp_tensortype_sizes[cat_dim_value];
      auto end = graph_->insertConstant(end_idx);

      auto slice = graph_->create(
          aten::slice, {cat_out_value, cat_dim, start, end, one});
      GRAPH_UPDATE("Inserting\n", *slice, "before\n", *node);
      slice->insertBefore(node);
      slices_added_.push_back(slice);

      // Insert a copy from this input to the output slice.
      auto copy = graph_->create(aten::copy_, {slice->output(), cat_inp});
      GRAPH_UPDATE("Inserting\n", *copy, "before\n", *node);
      copy->insertBefore(node);
      copies_added_.push_back(copy);

      // The next input's slice begins where this one ended.
      start_idx = end_idx;
      start = end;
    }

    // Replace the uses of `cat` node with the cat output buffer.
    replace_uses_with_[node->output()] = cat_out_value;
    nodes_to_remove_.insert(node);
  }

  // A tensor-typed value's shape is "known" only if it is a complete
  // tensor with non-zero rank; non-tensor values count as known.
  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
        return false;
      }
    }
    return true;
  }

  // True iff the shapes of every input and output of `node` are known.
  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  // Redirect all uses of each expanded cat's output to its new buffer and
  // destroy the now-dead cat nodes (and their dead input lists).
  void cleanupExpandedCatOps() {
    for (auto it : replace_uses_with_) {
      GRAPH_UPDATE(
          "Replacing uses of\n",
          *it.first->node(),
          "with\n",
          *it.second->node());
      it.first->replaceAllUsesWith(it.second);
    }
    for (auto n : nodes_to_remove_) {
      removeCatNodeFromGraph(n);
    }
  }

  // Move `node` before `before`, first recursively moving every node it
  // depends on so all producer/consumer ordering stays valid.
  void moveBefore(Node* node, Node* before) {
    // In order to move a node before another node, we need to move
    // all the nodes it depends on as well.
    for (auto inp : node->inputs()) {
      moveBefore(inp->node(), before);
    }
    node->moveBefore(before);
  }

  // Reuse buffers in copies wherever possible.
  //
  // For example, consider the following sequence of ops:
  //     %10 = prim::ListConstruct(%0, %1)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //     %12 = prim::ListConstruct(%11, %2)  // Uses the result of above cat
  //     %13 = aten::cat(%12, ...)
  //
  // Once these cat ops are expanded into copies, we will have two buffers; one
  // for %11 and another for %13. This can be optimized by using only one
  // buffer. We can only have the buffer that represents %13 and use a view
  // (slice) of that one as the buffer for %11.
  //
  // If any of the copies added earlier has `aten::empty` as its source,
  // those cases can be replaced with a single buffer.
  //
  // Example:
  //     %20 = aten::empty(...)          // cat.1 output buffer
  //     %21 = aten::slice(%20, ...)
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%20, ...)
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     %30 = aten::empty(...)          // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)
  //     %32 = aten::copy_(%31, %20)     // src of copy is aten::empty
  //                                     // so, we reuse this buffer above
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  //
  // After reusing copy buffers:
  //     %30 = aten::empty(...)          // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)     // move %31 and inputs before %20
  //     %21 = aten::slice(%31, ...)     // use %31 in place of %20
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%31, ...)     // use %31 in place of %20
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     ...                             // copy to %31 is now removed
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  void reuseBuffersInCopies() {
    for (auto copy : copies_added_) {
      auto src = copy->input(1);
      auto dst = copy->input(0);
      // Only copies whose source is a raw buffer (aten::empty) can be
      // folded into the destination buffer.
      if (src->node()->kind() != aten::empty) {
        continue;
      }

      // Move the destination node before the source.
      GRAPH_UPDATE("Moving\n", *dst->node(), "before\n", *src->node());
      moveBefore(dst->node(), src->node());

      GRAPH_UPDATE("Replacing\n", *src->node(), "with\n", *dst->node());
      src->replaceAllUsesWith(dst);

      GRAPH_UPDATE("Deleting\n", *src->node());
      src->node()->destroy();

      GRAPH_UPDATE("Deleting\n", *copy);
      copy->destroy();
    }
  }

  // Build the alias database lazily, on first use.
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  // Cat nodes replaced by expansion; destroyed in cleanupExpandedCatOps().
  std::unordered_set<Node*> nodes_to_remove_;
  // Old cat output value -> new buffer value that replaces it.
  std::unordered_map<Value*, Value*> replace_uses_with_;
  // aten::copy_ nodes inserted by expandCat, in insertion order.
  std::vector<Node*> copies_added_;
  // aten::slice nodes inserted by expandCat, in insertion order.
  std::vector<Node*> slices_added_;
};
494 
495 } // namespace
496 
ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph> & graph)497 void ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph>& graph) {
498   ConcatExpander(graph).run();
499   GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph);
500 }
501 
502 namespace {
503 
determineUsageIdx(Value * value,Node * user)504 size_t determineUsageIdx(Value* value, Node* user) {
505   const auto idx =
506       std::find(user->inputs().begin(), user->inputs().end(), value) -
507       user->inputs().begin();
508   using c10::ssize;
509   TORCH_CHECK(idx != ssize(user->inputs()));
510   return idx;
511 }
512 
getConcatInputs(Node * concat)513 std::vector<Value*> getConcatInputs(Node* concat) {
514   TORCH_CHECK(concat->kind() == aten::cat);
515   auto* list = concat->input(0);
516   auto* list_construct = list->node();
517   TORCH_CHECK(list_construct->kind() == prim::ListConstruct);
518   return list_construct->inputs().vec();
519 }
520 
// Merges chained aten::cat ops that concatenate along the same dim:
//     cat(cat(a, b), c)  ->  cat(a, b, c)
// (expressed through the intermediate prim::ListConstruct nodes). Pairs
// are collected first, then rewritten; dead nodes are removed by DCE.
class ConcatCombiner {
 public:
  explicit ConcatCombiner(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(graph_) {}

  // Runs the pass; returns true if the graph was changed.
  bool run() {
    collectOptimizableConcats();
    bool changed = combineConcats();
    if (changed) {
      EliminateDeadCode(graph_);
    }
    return changed;
  }

 private:
  // Given a concat node, see if it can be optimized with another.
  // If so, add a CombinablePair to combinable_concats_.
  void handleConcat(Node* node) {
    auto* list = node->input(0);
    auto* list_node = list->node();

    const auto dim_opt = toIValue(node->input(1));
    // We need to be able to determine dim statically to match it with another
    // concat.
    if (!dim_opt || !dim_opt->isInt()) {
      return;
    }
    const auto dim = dim_opt->toInt();

    // Check that the input of this node is an unmodified list construct
    if (list_node->kind() != prim::ListConstruct ||
        !aliasDb_.couldMoveBeforeTopologically(list_node, node)) {
      return;
    }

    // Check that the only output of this node is used in an unmodified list
    // construct.
    const auto& concat_uses = node->output()->uses();
    if (concat_uses.size() != 1) {
      return;
    }

    auto* next_list = concat_uses[0].user;
    if (next_list->kind() != prim::ListConstruct) {
      return;
    }

    const auto& next_list_uses = next_list->output()->uses();
    if (next_list_uses.size() != 1) {
      return;
    }

    auto* next_concat = next_list_uses[0].user;

    if (next_concat->kind() == aten::cat) {
      // Dimension must be determined statically and match the one we've already
      // seen.
      const auto next_dim_opt = toIValue(next_concat->input(1));
      if (!next_dim_opt || next_dim_opt->toInt() != dim) {
        return;
      }
      combinable_concats_.emplace_back(
          node, next_concat, determineUsageIdx(node->output(), next_list));
    }
  }

  // Walk the whole graph (depth-first, including sub-blocks) and record
  // every combinable pair of cats.
  void collectOptimizableConcats() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    for (auto* node = graph_it.next(); node != nullptr;
         node = graph_it.next()) {
      if (node->kind() == aten::cat) {
        handleConcat(node);
      }
    }
  }

  // Build (but do not insert) a prim::ListConstruct over `inputs`.
  Node* createListConstruct(const std::deque<Value*>& inputs) {
    auto* output = graph_->create(prim::ListConstruct);
    for (auto* v : inputs) {
      output->addInput(v);
    }
    return output;
  }

  using ListConstructInputs = std::shared_ptr<std::deque<Value*>>;
  // Construct a map (concat node) -> (new list inputs for this node).
  // std::deque is used so we can do O(1) insertions to the front.
  std::unordered_map<Node*, ListConstructInputs> getListConstructInputs() {
    std::unordered_map<Node*, ListConstructInputs> cur_list_construct_inputs;
    for (const auto& combinable : combinable_concats_) {
      // Combine the list inputs of first_concat with those of second_concat
      const auto& inputs_to_add = getConcatInputs(combinable.second_concat);

      auto it = cur_list_construct_inputs.find(combinable.first_concat);
      std::shared_ptr<std::deque<Value*>> cur_list;
      if (it != cur_list_construct_inputs.end()) {
        cur_list = it->second;
        // We're moving all inputs to second_concat.
        cur_list_construct_inputs.erase(combinable.first_concat);
      } else {
        cur_list = std::make_shared<std::deque<Value*>>();
      }
      cur_list_construct_inputs.emplace(combinable.second_concat, cur_list);

      // If cur_list is not empty, it's guaranteed to already contain all of
      // first_concat's inputs.
      if (cur_list->empty()) {
        const auto& starting_values = getConcatInputs(combinable.first_concat);
        cur_list->insert(
            cur_list->end(), starting_values.begin(), starting_values.end());
      }

      // Splice first_concat's inputs into second_concat's input list at
      // position `idx`: everything before idx goes in front, everything
      // after idx goes at the back (idx itself is first_concat's output,
      // which is being replaced).
      cur_list->insert(
          cur_list->begin(),
          inputs_to_add.begin(),
          inputs_to_add.begin() + combinable.idx);

      cur_list->insert(
          cur_list->end(),
          inputs_to_add.begin() + combinable.idx + 1,
          inputs_to_add.end());
    }
    return cur_list_construct_inputs;
  }

  // Materialize the new, flattened list constructs and redirect the old
  // ones. Returns true when at least one pair was combined.
  bool combineConcats() {
    if (combinable_concats_.empty()) {
      return false;
    }

    auto list_construct_inputs = getListConstructInputs();

    for (const auto& node_and_new_list : list_construct_inputs) {
      auto* node = node_and_new_list.first;
      auto& inputs = node_and_new_list.second;

      auto* new_list_construct = createListConstruct(*inputs);
      auto* old_list_construct = node->input(0)->node();
      new_list_construct->output()->setType(
          old_list_construct->output()->type());
      new_list_construct->insertBefore(node);
      old_list_construct->replaceAllUsesWith(new_list_construct);
    }
    return true;
  }

  // Represents an optimizable pair of concat nodes.
  // - first_concat must appear before second_concat
  // - idx is the index where first_concat's inputs must be inserted into
  //   second_concat's new inputs.
  // Example:
  //    %inputs.1 = prim::ListConstruct(%0, %0)
  //    %concat.1 = aten::cat(%inputs.1, %dim)
  //    %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
  //    %concat.2 = aten::cat(%inputs.2, %dim)
  // -> first_concat = &concat.1, second_concat = &concat.2, idx = 1
  struct CombinableConcat {
    CombinableConcat(Node* a, Node* b, size_t i)
        : first_concat(a), second_concat(b), idx(i) {}

    Node* first_concat;
    Node* second_concat;
    size_t idx;
  };

  std::vector<CombinableConcat> combinable_concats_;

  std::shared_ptr<Graph> graph_;
  AliasDb aliasDb_;
};
691 
692 } // namespace
693 
CombineConcats(const std::shared_ptr<Graph> & graph)694 bool CombineConcats(const std::shared_ptr<Graph>& graph) {
695   bool changed = ConcatCombiner(graph).run();
696   GRAPH_DUMP("After combining concats", graph);
697   return changed;
698 }
699 
700 } // namespace torch::jit
701