1 #include <torch/csrc/jit/passes/concat_opt.h>
2
3 #include <algorithm>
4 #include <deque>
5 #include <unordered_map>
6 #include <unordered_set>
7 #include <vector>
8
9 #include <c10/util/ssize.h>
10 #include <torch/csrc/jit/ir/alias_analysis.h>
11 #include <torch/csrc/jit/ir/ir.h>
12 #include <torch/csrc/jit/ir/named_value.h>
13 #include <torch/csrc/jit/jit_log.h>
14 #include <torch/csrc/jit/passes/constant_pooling.h>
15 #include <torch/csrc/jit/passes/dead_code_elimination.h>
16 #include <torch/csrc/jit/passes/remove_mutation.h>
17 #include <torch/csrc/jit/runtime/graph_iterator.h>
18
19 namespace torch::jit {
20
21 namespace {
22
removeCatNodeFromGraph(Node * n)23 void removeCatNodeFromGraph(Node* n) {
24 TORCH_INTERNAL_ASSERT(n->kind() == aten::cat);
25 auto inp_list = n->input(0);
26 GRAPH_UPDATE("Deleting\n", *n);
27 n->destroy();
28 if (!inp_list->hasUses()) {
29 GRAPH_UPDATE("Deleting\n", *inp_list->node());
30 inp_list->node()->destroy();
31 }
32 }
33
equal(at::ArrayRef<Value * > list1,at::ArrayRef<Value * > list2)34 bool equal(at::ArrayRef<Value*> list1, at::ArrayRef<Value*> list2) {
35 return list1.size() == list2.size() &&
36 std::equal(list1.begin(), list1.end(), list2.begin());
37 }
38
// Eliminates redundant shared inputs across prim::VarConcat nodes: when a
// later VarConcat repeats, along the same dim, the exact tensor inputs of an
// earlier VarConcat as a prefix or a suffix, the later concat is rewritten
// to take the earlier concat's output instead of re-listing those tensors.
class ConcatCommonInputsEliminator {
 public:
  explicit ConcatCommonInputsEliminator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Scans the graph for rewrite opportunities, then applies them.
  // Returns true if the graph was changed.
  bool run() {
    handleBlock(graph_->block());
    return postprocess();
  }

 private:
  // Visits every node in `block` and, recursively, in all subblocks;
  // records a planned rewrite for each eligible prim::VarConcat.
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == prim::VarConcat) {
        handleCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Considers one prim::VarConcat node. Its inputs are the tensors to be
  // concatenated followed by the concat dimension as the last input.
  void handleCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for CSE opt: ", node);

    auto curr_all_inputs = node->inputs();
    // Tensor inputs only; the trailing input is the dim.
    auto curr_tensor_inputs =
        curr_all_inputs.slice(0, curr_all_inputs.size() - 1);
    auto curr_dim = curr_all_inputs.back();

    // Save the input list and the current cat node, so that this can be
    // used for subsequent cat nodes, unless there are writes to this cat
    // node. When there are writes to this cat node, its output does not
    // represent this concatenated list beyond the writes. Currently, we do
    // not perform such fine-grained analysis. So, if there are any writes to
    // the output, we do not use this cat node for optimization here.
    if (!getOrCreateAliasDb()->hasWriters(node->output())) {
      concated_outputs_.insert(node);
    }

    if (curr_tensor_inputs.size() <= 2) {
      // The case when concat has 2 input tensors could only be optimized if
      // there is another concat of the exact same 2 input tensors. That case
      // is expected to be handled by the CSE pass.
      return;
    }

    // Now, we check if the first N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14

    auto curr_tensor_inputs_prefix =
        curr_tensor_inputs.slice(0, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_prefix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated output if it does not
          // dominate the current concat node.
          continue;
        }

        // New form: (prev concat output, last original tensor, dim).
        std::vector<Value*> new_inputs = {
            prev->output(), curr_tensor_inputs.back(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        // Defer the actual insertion/replacement to postprocess().
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Now, we check if the last N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %10 = prim::ListConstruct(%1, %2)
    //    %11 = aten::cat(%10, ...)
    //    ...
    //    %12 = prim::ListConstruct(%0, %1, %2)  // last 2 inputs same as %11
    //    %13 = aten::cat(%12, ...)
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %10 = prim::ListConstruct(%0, %1)
    //    %11 = aten::cat(%10, ...)
    //    ...
    //    %12 = prim::ListConstruct(%0, %11) // Replace last 2 inputs with %11
    //    %13 = aten::cat(%12, ...)
    //    ...
    //        = %13 ... // Use %13
    auto curr_tensor_inputs_suffix =
        curr_tensor_inputs.slice(1, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_suffix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated list if it does not
          // dominate the current list.
          continue;
        }

        // New form: (first original tensor, prev concat output, dim).
        std::vector<Value*> new_inputs = {
            curr_tensor_inputs.front(), prev->output(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Do we need to handle other cases where N-2 or lesser elements from
    // %inputs appear in any of the previous cat ops?
    // TODO.
  }

  // Applies all rewrites planned in handleCat: inserts each replacement
  // concat before the original, redirects uses, and destroys the original.
  // Returns true if any replacement was performed.
  bool postprocess() {
    // Replace the list nodes that have been marked.
    bool changed = false;
    for (auto it : concats_to_replace_) {
      auto curr_node = it.first;
      auto new_node = it.second;
      GRAPH_UPDATE("Inserting\n", *new_node, "before\n", *curr_node);
      new_node->insertBefore(curr_node);
      GRAPH_UPDATE("Replacing uses of\n", *curr_node, "with\n", *new_node);
      curr_node->output()->replaceAllUsesWith(new_node->output());
      GRAPH_UPDATE("Deleting\n", *curr_node);
      curr_node->destroy();
      changed = true;
    }
    return changed;
  }

  // Lazily builds the alias database; constructing it is not free, so it is
  // created only when the first candidate concat is inspected.
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  // VarConcat nodes whose outputs have no writers; candidates for reuse.
  std::unordered_set<Node*> concated_outputs_;
  // Map from an original concat node to its (not-yet-inserted) replacement.
  std::unordered_map<Node*, Node*> concats_to_replace_;
};
209
210 } // namespace
211
EliminateConcatCommonInputs(const std::shared_ptr<Graph> & graph)212 bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph) {
213 GRAPH_DUMP("Before eliminating Concat common inputs", graph);
214 bool changed = ConcatCommonInputsEliminator(graph).run();
215 if (changed) {
216 GRAPH_DUMP("After eliminating Concat common inputs", graph);
217 }
218 return changed;
219 }
220
221 namespace {
222
// Rewrites aten::cat nodes with fully-known static shapes into an explicit
// output buffer (aten::empty) filled via per-input aten::slice + aten::copy_
// pairs, then merges chained cat buffers so that a cat feeding another cat
// writes directly into a slice of the outer buffer.
class ConcatExpander {
 public:
  explicit ConcatExpander(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Entry point: expand eligible cats, remove the originals, then attempt
  // buffer reuse across the copies that were inserted.
  void run() {
    handleBlock(graph_->block());
    cleanupExpandedCatOps();
    GRAPH_DUMP("Before reusing copy buffers: ", graph_);
    reuseBuffersInCopies();
  }

 private:
  // Visits every node in `block` and all subblocks, expanding each
  // aten::cat encountered.
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == aten::cat) {
        expandCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Expand cat node into multiple copy nodes.
  //
  // Example:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %10 = prim::ListConstruct(%2, %3)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //         = %11 ... // Use %11
  //
  // After expanding cat:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %20 = aten::empty(...)          // cat output buffer
  //     %21 = aten::slice(%20, ...)     // slice for %2
  //     %22 = aten::copy_(%21, %2)      // copy %2
  //     %23 = aten::slice(%20, ...)     // slice for %3
  //     %24 = aten::copy_(%23, %3)      // copy %3
  //     ...
  //         = %20 ... // Use %20 in place of %11
  void expandCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for expansion: ", node);
    // Do not optimize cat nodes whose inputs are mutated in the graph.
    // TODO: Improve this by checking if it is mutated in the graph region
    // where this optimization is applied.
    if (getOrCreateAliasDb()->hasWriters(node->input(0))) {
      return;
    }
    if (node->input(0)->node()->kind() != prim::ListConstruct) {
      // Unknown form of input to `cat` op.
      return;
    }
    if (!allShapesAreKnown(node)) {
      // Can't expand when shapes are not known for the `cat` op.
      return;
    }
    for (auto cat_inp : node->input(0)->node()->inputs()) {
      if (!shapeIsKnown(cat_inp)) {
        // Can't expand when shapes of the inputs to `cat` are not known.
        return;
      }
    }
    // TODO: Handle non-contiguous Tensors.
    // For example, how to handle the cases where the inputs are all channels
    // last?

    auto maybe_cat_dim = constant_as<int64_t>(node->input(1));
    if (!maybe_cat_dim) {
      // Can't expand when cat dimension is not a constant.
      return;
    }
    auto cat_dim_value = maybe_cat_dim.value();
    auto cat_dim = node->input(1);

    // Set the insertion point to the current `cat` node.
    WithInsertPoint guard(node);
    auto none = graph_->insertConstant(IValue());
    auto one = graph_->insertConstant(1);

    // Insert the constants needed for the `cat` output buffer size.
    auto tensortype = node->output()->type()->expect<TensorType>();
    TORCH_INTERNAL_ASSERT(tensortype);
    auto tensortype_sizes = tensortype->sizes();
    std::vector<Value*> cat_out_size;
    for (size_t i = 0; i < tensortype_sizes.size(); ++i) {
      cat_out_size.push_back(graph_->insertConstant(tensortype_sizes[i]));
    }

    // Create a list of int for `cat` output buffer size.
    auto cat_out_size_list = graph_->createList(IntType::get(), cat_out_size);
    cat_out_size_list->insertBefore(node);

    // Create an empty buffer to be used as `cat` output buffer.
    // TODO: Handle tensors with different dtype, layout, device, memory
    // format, etc.
    auto cat_out_empty = graph_->create(
        aten::empty,
        {cat_out_size_list->output(), none, none, none, none, none});
    cat_out_empty->insertBefore(node);

    // For every input to this `cat` node:
    //   * Create a slice of `cat` output buffer.
    auto cat_out_value = cat_out_empty->output();
    auto cat_inp_list = node->input(0)->node();
    int64_t start_idx = 0;
    auto start = graph_->insertConstant(start_idx);
    for (auto cat_inp : cat_inp_list->inputs()) {
      // Create a slice of the cat output buffer that correspond to
      // this input size and position in the output.
      auto cat_inp_tensor_type =
          dynamic_cast<TensorType*>(cat_inp->type().get());
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type);
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type->dim());
      auto cat_inp_tensortype_sizes = cat_inp_tensor_type->sizes();
      // End of this slice = running offset + this input's extent along the
      // concat dim.
      auto end_idx = start_idx + *cat_inp_tensortype_sizes[cat_dim_value];
      auto end = graph_->insertConstant(end_idx);

      auto slice = graph_->create(
          aten::slice, {cat_out_value, cat_dim, start, end, one});
      GRAPH_UPDATE("Inserting\n", *slice, "before\n", *node);
      slice->insertBefore(node);
      slices_added_.push_back(slice);

      // Insert a copy from this input to the output slice.
      auto copy = graph_->create(aten::copy_, {slice->output(), cat_inp});
      GRAPH_UPDATE("Inserting\n", *copy, "before\n", *node);
      copy->insertBefore(node);
      copies_added_.push_back(copy);

      // Next slice starts where this one ended.
      start_idx = end_idx;
      start = end;
    }

    // Replace the uses of `cat` node with the cat output buffer.
    replace_uses_with_[node->output()] = cat_out_value;
    nodes_to_remove_.insert(node);
  }

  // A tensor's shape counts as "known" when it is a complete tensor type
  // with non-zero rank; non-tensor values pass trivially.
  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
        return false;
      }
    }
    return true;
  }

  // True when every input and output of `node` has a known shape.
  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  // Redirects uses of each expanded cat's output to its buffer and removes
  // the now-dead cat nodes (and their input lists) from the graph.
  void cleanupExpandedCatOps() {
    for (auto it : replace_uses_with_) {
      GRAPH_UPDATE(
          "Replacing uses of\n",
          *it.first->node(),
          "with\n",
          *it.second->node());
      it.first->replaceAllUsesWith(it.second);
    }
    for (auto n : nodes_to_remove_) {
      removeCatNodeFromGraph(n);
    }
  }

  void moveBefore(Node* node, Node* before) {
    // In order to move a node before another node, we need to move
    // all the nodes it depends on as well.
    for (auto inp : node->inputs()) {
      moveBefore(inp->node(), before);
    }
    node->moveBefore(before);
  }

  // Reuse buffers in copies wherever possible.
  //
  // For example, consider the following sequence of ops:
  //     %10 = prim::ListConstruct(%0, %1)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //     %12 = prim::ListConstruct(%11, %2)  // Uses the result of above cat
  //     %13 = aten::cat(%12, ...)
  //
  // Once these cat ops are expanded into copies, we will have two buffers; one
  // for %11 and another for %13. This can be optimized by using only one
  // buffer. We can only have the buffer that represents %13 and use a view
  // (slice) of that one as the buffer for %11.
  //
  // If any of the copies added earlier has `aten::empty` as its source,
  // those cases can be replaced with a single buffer.
  //
  // Example:
  //     %20 = aten::empty(...)          // cat.1 output buffer
  //     %21 = aten::slice(%20, ...)
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%20, ...)
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     %30 = aten::empty(...)          // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)
  //     %32 = aten::copy_(%31, %20)     // src of copy is aten::empty
  //                                     // so, we reuse this buffer above
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  //
  // After reusing copy buffers:
  //     %30 = aten::empty(...)          // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)     // move %31 and inputs before %20
  //     %21 = aten::slice(%31, ...)     // use %31 in place of %20
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%31, ...)     // use %31 in place of %20
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     ...                             // copy to %31 is now removed
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  void reuseBuffersInCopies() {
    for (auto copy : copies_added_) {
      auto src = copy->input(1);
      auto dst = copy->input(0);
      // Only copies whose source is a freshly-created empty buffer can be
      // collapsed into the destination buffer.
      if (src->node()->kind() != aten::empty) {
        continue;
      }

      // Move the destination node before the source.
      GRAPH_UPDATE("Moving\n", *dst->node(), "before\n", *src->node());
      moveBefore(dst->node(), src->node());

      GRAPH_UPDATE("Replacing\n", *src->node(), "with\n", *dst->node());
      src->replaceAllUsesWith(dst);

      GRAPH_UPDATE("Deleting\n", *src->node());
      src->node()->destroy();

      GRAPH_UPDATE("Deleting\n", *copy);
      copy->destroy();
    }
  }

  // Lazily builds the alias database on first use.
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  // Expanded cat nodes awaiting removal in cleanupExpandedCatOps().
  std::unordered_set<Node*> nodes_to_remove_;
  // Map from an expanded cat's output to its replacement buffer value.
  std::unordered_map<Value*, Value*> replace_uses_with_;
  // All aten::copy_ / aten::slice nodes inserted by expandCat(), in
  // insertion order; consumed by reuseBuffersInCopies().
  std::vector<Node*> copies_added_;
  std::vector<Node*> slices_added_;
};
494
495 } // namespace
496
ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph> & graph)497 void ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph>& graph) {
498 ConcatExpander(graph).run();
499 GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph);
500 }
501
502 namespace {
503
determineUsageIdx(Value * value,Node * user)504 size_t determineUsageIdx(Value* value, Node* user) {
505 const auto idx =
506 std::find(user->inputs().begin(), user->inputs().end(), value) -
507 user->inputs().begin();
508 using c10::ssize;
509 TORCH_CHECK(idx != ssize(user->inputs()));
510 return idx;
511 }
512
getConcatInputs(Node * concat)513 std::vector<Value*> getConcatInputs(Node* concat) {
514 TORCH_CHECK(concat->kind() == aten::cat);
515 auto* list = concat->input(0);
516 auto* list_construct = list->node();
517 TORCH_CHECK(list_construct->kind() == prim::ListConstruct);
518 return list_construct->inputs().vec();
519 }
520
// Combines chains of aten::cat ops along the same dimension: when a cat's
// only use is inside the ListConstruct feeding another cat with the same
// dim, the inner cat's inputs are spliced directly into the outer cat's
// input list, eliminating the intermediate concat.
class ConcatCombiner {
 public:
  explicit ConcatCombiner(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(graph_) {}

  // Collects combinable cat pairs, applies the rewrites, and runs DCE to
  // remove the orphaned inner cats. Returns true if the graph changed.
  bool run() {
    collectOptimizableConcats();
    bool changed = combineConcats();
    if (changed) {
      EliminateDeadCode(graph_);
    }
    return changed;
  }

 private:
  // Given a concat node, see if it can be optimized with another.
  // If so, add a CombinablePair to combinable_concats_.
  void handleConcat(Node* node) {
    auto* list = node->input(0);
    auto* list_node = list->node();

    const auto dim_opt = toIValue(node->input(1));
    // We need to be able to determine dim statically to match it with another
    // concat.
    if (!dim_opt || !dim_opt->isInt()) {
      return;
    }
    const auto dim = dim_opt->toInt();

    // Check that the input of this node is an unmodified list construct
    if (list_node->kind() != prim::ListConstruct ||
        !aliasDb_.couldMoveBeforeTopologically(list_node, node)) {
      return;
    }

    // Check that the only output of this node is used in an unmodified list
    // construct.
    const auto& concat_uses = node->output()->uses();
    if (concat_uses.size() != 1) {
      return;
    }

    auto* next_list = concat_uses[0].user;
    if (next_list->kind() != prim::ListConstruct) {
      return;
    }

    const auto& next_list_uses = next_list->output()->uses();
    if (next_list_uses.size() != 1) {
      return;
    }

    auto* next_concat = next_list_uses[0].user;

    if (next_concat->kind() == aten::cat) {
      // Dimension must be determined statically and match the one we've already
      // seen.
      const auto next_dim_opt = toIValue(next_concat->input(1));
      if (!next_dim_opt || next_dim_opt->toInt() != dim) {
        return;
      }
      combinable_concats_.emplace_back(
          node, next_concat, determineUsageIdx(node->output(), next_list));
    }
  }

  // Walks the whole graph depth-first and inspects every aten::cat.
  void collectOptimizableConcats() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    for (auto* node = graph_it.next(); node != nullptr;
         node = graph_it.next()) {
      if (node->kind() == aten::cat) {
        handleConcat(node);
      }
    }
  }

  // Builds a (not-yet-inserted) prim::ListConstruct node over `inputs`.
  Node* createListConstruct(const std::deque<Value*>& inputs) {
    auto* output = graph_->create(prim::ListConstruct);
    for (auto* v : inputs) {
      output->addInput(v);
    }
    return output;
  }

  using ListConstructInputs = std::shared_ptr<std::deque<Value*>>;
  // Construct a map (concat node) -> (new list inputs for this node).
  // std::deque is used so we can do O(1) insertions to the front.
  std::unordered_map<Node*, ListConstructInputs> getListConstructInputs() {
    std::unordered_map<Node*, ListConstructInputs> cur_list_construct_inputs;
    for (const auto& combinable : combinable_concats_) {
      // Combine the list inputs of first_concat with those of second_concat
      const auto& inputs_to_add = getConcatInputs(combinable.second_concat);

      auto it = cur_list_construct_inputs.find(combinable.first_concat);
      std::shared_ptr<std::deque<Value*>> cur_list;
      if (it != cur_list_construct_inputs.end()) {
        cur_list = it->second;
        // We're moving all inputs to second_concat.
        cur_list_construct_inputs.erase(combinable.first_concat);
      } else {
        cur_list = std::make_shared<std::deque<Value*>>();
      }
      cur_list_construct_inputs.emplace(combinable.second_concat, cur_list);

      // If cur_list is not empty, it's guaranteed to already contain all of
      // first_concat's inputs.
      if (cur_list->empty()) {
        const auto& starting_values = getConcatInputs(combinable.first_concat);
        cur_list->insert(
            cur_list->end(), starting_values.begin(), starting_values.end());
      }

      // Prepend second_concat's inputs that come before the inner concat
      // (positions [0, idx)) ...
      cur_list->insert(
          cur_list->begin(),
          inputs_to_add.begin(),
          inputs_to_add.begin() + combinable.idx);

      // ... and append the ones that come after it (positions (idx, end)).
      cur_list->insert(
          cur_list->end(),
          inputs_to_add.begin() + combinable.idx + 1,
          inputs_to_add.end());
    }
    return cur_list_construct_inputs;
  }

  // Replaces each outer concat's ListConstruct with a new one holding the
  // flattened inputs computed by getListConstructInputs().
  bool combineConcats() {
    if (combinable_concats_.empty()) {
      return false;
    }

    auto list_construct_inputs = getListConstructInputs();

    for (const auto& node_and_new_list : list_construct_inputs) {
      auto* node = node_and_new_list.first;
      auto& inputs = node_and_new_list.second;

      auto* new_list_construct = createListConstruct(*inputs);
      auto* old_list_construct = node->input(0)->node();
      new_list_construct->output()->setType(
          old_list_construct->output()->type());
      new_list_construct->insertBefore(node);
      old_list_construct->replaceAllUsesWith(new_list_construct);
    }
    return true;
  }

  // Represents an optimizable pair of concat nodes.
  // - first_concat must appear before second_concat
  // - idx is the index where first_concat's inputs must be inserted into
  //   second_concat's new inputs.
  // Example:
  //    %inputs.1 = prim::ListConstruct(%0, %0)
  //    %concat.1 = aten::cat(%inputs.1, %dim)
  //    %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
  //    %concat.2 = aten::cat(%inputs.2, %dim)
  // -> first_concat = &concat.1, second_concat = &concat.2, idx = 1
  struct CombinableConcat {
    CombinableConcat(Node* a, Node* b, size_t i)
        : first_concat(a), second_concat(b), idx(i) {}

    Node* first_concat;
    Node* second_concat;
    size_t idx;
  };

  std::vector<CombinableConcat> combinable_concats_;

  std::shared_ptr<Graph> graph_;
  AliasDb aliasDb_;
};
691
692 } // namespace
693
CombineConcats(const std::shared_ptr<Graph> & graph)694 bool CombineConcats(const std::shared_ptr<Graph>& graph) {
695 bool changed = ConcatCombiner(graph).run();
696 GRAPH_DUMP("After combining concats", graph);
697 return changed;
698 }
699
700 } // namespace torch::jit
701