xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/lower_tuples.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/lower_tuples.h>
2 
3 #include <ATen/core/functional.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/ir/constants.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 
10 #include <utility>
11 
12 namespace torch::jit {
13 
14 namespace {
15 
16 // operators where we expect to find tuples as inputs/outputs
17 // this is to assert we are only doing modifications when we know
18 // we can flatten tuples
19 std::unordered_set<Symbol> supported_ops = {
20     prim::If,
21     prim::Loop,
22     prim::Uninitialized,
23     prim::TupleUnpack,
24     prim::TupleConstruct,
25     prim::TupleIndex,
26     prim::TupleSlice,
27     prim::Param,
28     prim::Return,
29     prim::PythonOp,
30     aten::format,
31     prim::Uninitialized,
32     aten::__getitem__};
33 
// Flatten block inputs and insert a tuple construct in the block
static void flattenTupleInLoopParams(Node* n, size_t index) {
  // Input `index` of the loop node `n` must be tuple-typed; the caller
  // (flattenInputs) has already verified it is produced by a
  // prim::TupleConstruct.
  auto input = n->inputs().at(index);
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  // The loop body is the node's first block.
  Block* block = n->blocks().at(0);
  Node* block_node = n;

  // NOTE(review): unused local, kept as-is (doc-only change).
  std::vector<Value*> new_node_inputs = {};
  // Re-pack the flattened block inputs into a tuple at the top of the body,
  // so existing uses of the tuple-typed block param keep working.
  auto new_construct_node =
      block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    // One new block input per tuple element, feeding the re-pack node...
    auto new_block_in = block->insertInput(index + j);
    new_construct_node->addInput(new_block_in);
    // ...and one new node input per element, spliced from the construct that
    // feeds the loop. Block inputs are offset by one relative to node inputs
    // (the loop body carries an extra leading element — see the iter-counter
    // comment in flattenTupleInBlockReturn) — hence `index + j + 1` here and
    // `index - 1` below.
    block_node->insertInput(index + j + 1, input->node()->inputs().at(j));
  }
  new_construct_node->output()->setType(block->inputs().at(index - 1)->type());
  new_construct_node->copyMetadata(n);
  // Redirect all body uses of the old tuple param to the re-packed tuple,
  // then drop the now-unused tuple-typed block param and node input.
  block->inputs().at(index - 1)->replaceAllUsesWith(
      new_construct_node->output());
  block->eraseInput(index - 1);
  block_node->removeInput(index);
}
58 
// Flatten tuple outputs of the block node and append a TupleConstruct
// node after the block node if there is an outer block.
static void flattenTupleInBlockReturn(Node* n, size_t index) {
  // `n` is a prim::Return node; its input at `index` is tuple-typed and the
  // caller has verified it comes from a prim::TupleConstruct.
  auto input = n->inputs().at(index);
  Block* block = n->owningBlock();
  Node* block_node = block->owningNode();
  Node* new_construct_node = nullptr;
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  // 1- Add flattened tuple to block outputs
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    block->insertOutput(index + j + 1, input->node()->inputs().at(j));
  }
  block->eraseOutput(index);

  // Top-level graph return (no owning node): nothing more to rewrite.
  if (block_node == nullptr)
    return;
  // 2- For uses of the block node in the outer block,
  // flatten the blocknode outputs and insert a tuple construct
  // to replace that.
  // Loop block has an extra element (iter counter)
  if (block_node->kind() == prim::Loop)
    index = index - 1;
  auto tuple_output = block_node->outputs().at(index);
  // When node has multiple blocks, do not flatten outputs on the second block
  // again
  if (!(tuple_output->type()->cast<TupleType>()))
    return;

  // Re-pack the flattened node outputs into a tuple right after the owning
  // node so downstream uses of the tuple-typed output keep working.
  new_construct_node = block->owningGraph()->create(prim::TupleConstruct);
  new_construct_node->insertAfter(block_node);
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    auto new_block_out = block_node->insertOutput(index + j + 1);
    new_construct_node->addInput(new_block_out);
  }
  // Replace the block node with the new TupleConstruct node
  new_construct_node->output()->setType(tuple_output->type());
  new_construct_node->copyMetadata(block_node);
  tuple_output->replaceAllUsesWith(new_construct_node->output());
  block_node->eraseOutput(index);
}
101 
removeTupleNodes(Node * n,bool must_remove_tuples)102 void removeTupleNodes(Node* n, bool must_remove_tuples) {
103   if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
104       n->kind() != prim::TupleSlice) {
105     return;
106   }
107   // tuple index has two inputs, tuple and index
108   auto construct_node = n->inputs().at(0)->node();
109   if (construct_node->kind() != prim::TupleConstruct) {
110     if (must_remove_tuples) {
111       AT_ERROR(n->kind().toQualString(), " not matched to tuple construct");
112     }
113     return;
114   }
115   if (n->kind() == prim::TupleUnpack) {
116     for (size_t i = 0; i < n->outputs().size(); ++i) {
117       n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i));
118     }
119   } else if (n->kind() == prim::TupleIndex) {
120     auto idx = n->inputs().at(1);
121     auto maybe_int = constant_as<int64_t>(idx);
122     if (!maybe_int) {
123       if (must_remove_tuples) {
124         AT_ERROR(n->sourceRange(), "tuple index with non-constant index");
125       }
126       return;
127     }
128     auto int_idx = *maybe_int;
129     size_t len = construct_node->output()->type()->containedTypes().size();
130     if (int_idx < 0) {
131       int_idx += len;
132     }
133     // currently, we allow non-constant tuple index if the tuple is of one type.
134     // so we need to check bounds here
135     if (int_idx >= 0 && static_cast<size_t>(int_idx) < len) {
136       n->output()->replaceAllUsesWith(construct_node->inputs().at(int_idx));
137     }
138   } else if (n->kind() == prim::TupleSlice) {
139     std::vector<Value*> values;
140     int64_t beg = n->i(attr::beg);
141     int64_t end = n->i(attr::end);
142     for (int64_t i = beg; i < end; i += 1) {
143       values.push_back(construct_node->inputs().at(i));
144     }
145     auto graph = n->owningGraph();
146     auto tuple_out = graph->createTuple(values);
147     tuple_out->copyMetadata(n);
148     WithInsertPoint insert(n);
149     graph->insertNode(tuple_out);
150     n->output()->replaceAllUsesWith(tuple_out->output());
151   }
152 }
153 } // anonymous namespace
154 
155 static void LowerAllTuples(Block* block);
156 
// Expands a prim::Constant holding a tuple into one constant per element plus
// a prim::TupleConstruct, recursing into elements that are themselves tuple
// constants. No-op for any other node.
static void RemoveTupleConstants(Node* n) {
  if (!(n->kind() == prim::Constant &&
        n->output()->type()->cast<TupleType>())) {
    return;
  }

  auto g = n->owningGraph();
  auto tuple = toIValue(n->output()).value().toTuple();
  const auto& tuple_elements = tuple->elements();
  // Insert the element constants and the construct where `n` currently sits.
  WithInsertPoint insert(n);
  std::vector<Value*> elements;
  for (const auto& elem : tuple_elements) {
    auto constant = insertConstant(*n->owningGraph(), elem);
    elements.push_back(constant);
  }
  auto tuple_type = n->output()->type()->expect<TupleType>();
  // Keep the original (possibly named) tuple type only when it carries a
  // schema; otherwise let createTuple infer a plain TupleType.
  auto tuple_construct = g->insertNode(n->owningGraph()->createTuple(
      elements, tuple_type->schema() ? std::move(tuple_type) : nullptr));
  tuple_construct->copyMetadata(n);

  // insert the tuple first before recursing on its elements, so that its
  // elements will have a use
  for (Value* elem : elements) {
    RemoveTupleConstants(elem->node());
  }

  n->replaceAllUsesWith(tuple_construct);
}
185 
// Flattens tuple-typed inputs of `n`:  op(a, tup, b) --> op(a, t0, t1, b).
// Loop and Return nodes are special-cased because their blocks' params /
// outputs must be flattened in lock-step with the node's inputs.
// NOTE(review): insert_point is unused here; kept for signature symmetry
// with flattenOutputs.
static void flattenInputs(Node* n, Node* insert_point) {
  // flatten the input list  op(a, tup, b) --> op(a, t0, t1, b)
  for (size_t i = 0; i < n->inputs().size();) {
    auto input = n->inputs()[i];
    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
      // Flattening requires a visible TupleConstruct whose elements can be
      // spliced in directly.
      TORCH_CHECK(
          (input->node()->kind() == prim::TupleConstruct),
          "tuple use not matched to tuple construct. Instead found: ",
          n->kind().toQualString());
      if (supported_ops.count(n->kind()) > 0) {
        if (n->kind() == prim::Loop) {
          // This function supports all node types with blocks that take tuple
          // inputs.
          flattenTupleInLoopParams(n, i);
        } else if (n->kind() == prim::Return) {
          flattenTupleInBlockReturn(n, i);
        } else {
          // Splice the construct's elements in after position i, then drop
          // the tuple-typed input itself.
          for (size_t j = 0; j < tt->elements().size(); ++j) {
            n->insertInput(i + 1 + j, input->node()->inputs().at(j));
          }
          n->removeInput(i);
        }
        // note: no update to i
        // since tuples might be nested we need to recursively scan
        // the new flattened inputs
      } else {
        TORCH_WARN(
            "tuple appears in op inputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
223 
// Flattens tuple-typed outputs of `n`:  (a, b, tup, c) -> (a, b, t0, t1, c),
// inserting `tup = TupleConstruct(t0, t1)` before `insert_point` so existing
// uses of the tuple output keep working until they are lowered themselves.
static void flattenOutputs(Node* n, Node* insert_point) {
  // flatten the outputs list
  auto& graph = *n->owningGraph();
  for (size_t i = 0; i < n->outputs().size();) {
    Value* output = n->outputs()[i];
    // Unused tuple outputs need no re-packing; skip them.
    if (!output->hasUses()) {
      ++i;
      continue;
    }

    // (a, b, tup, c) -> (a, b, t0, t1, c)
    // and:
    //    tup = (t0, t1)
    // is placed at the current insertion point
    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
      if (supported_ops.count(n->kind()) > 0) {
        // New outputs carry the element types of the original tuple.
        for (const auto j : c10::irange(tt->elements().size())) {
          n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
        }
        auto new_tup =
            graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
        new_tup->copyMetadata(n);
        new_tup->insertBefore(insert_point);
        // Chain subsequent re-pack nodes after this one to preserve order.
        insert_point = new_tup;
        output->replaceAllUsesWith(new_tup->output());
        n->eraseOutput(i);
        // note: no update to i to handle nested tuples
      } else {
        TORCH_WARN(
            "tuple appears in the op outputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
263 
VisitNode(Node * n,Node * insert_point)264 static void VisitNode(Node* n, Node* insert_point) {
265   // tuple construction operators will become dead when the unpacks are replaced
266   if (n->kind() == prim::TupleConstruct) {
267     return;
268   }
269   // note: changing the second argument to false changes this pass from a
270   // complete lowering pass to one that removes tuples when possible. When
271   // tuples are first-class in the interpreter, we should still run this pass to
272   // remove extraneous uses
273   if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
274       n->kind() == prim::TupleSlice) {
275     removeTupleNodes(n, /*must_remove_tuples*/ true);
276     return;
277   }
278   flattenInputs(n, insert_point);
279   for (auto b : n->blocks()) {
280     LowerAllTuples(b);
281   }
282   flattenOutputs(n, insert_point);
283 }
284 
LowerAllTuples(Block * block)285 static void LowerAllTuples(Block* block) {
286   // tuples in parameter lists of a block behave exactly the same as
287   // _outputs_ of normal instructions, since the param_node represents the
288   // parameters as outputs, we can handle it by simply visiting the node
289   VisitNode(block->param_node(), *block->nodes().begin());
290   for (auto it = block->nodes().begin(), end = block->nodes().end();
291        it != end;) {
292     auto n = *it++;
293     RemoveTupleConstants(n);
294     VisitNode(n, *it);
295   }
296   // tuples in return lists of blocks behave exactly the same as
297   // _inputs_ of normal instructions, so we can use VisitNode here as well
298   // insert_point is null because it will never be used since return nodes
299   // have no outputs
300   VisitNode(block->return_node(), nullptr);
301 }
302 
EnsureNoTuples(ArrayRef<Value * > values)303 static void EnsureNoTuples(ArrayRef<Value*> values) {
304   for (Value* v : values) {
305     TORCH_CHECK(
306         v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
307   }
308 }
309 
EnsureNoTuples(Block * block)310 static void EnsureNoTuples(Block* block) {
311   for (Node* n : block->nodes()) {
312     for (Block* b : n->blocks()) {
313       EnsureNoTuples(b);
314     }
315     EnsureNoTuples(n->outputs());
316   }
317 }
318 
LowerAllTuples(const std::shared_ptr<Graph> & graph)319 void LowerAllTuples(const std::shared_ptr<Graph>& graph) {
320   LowerAllTuples(graph->block());
321   GRAPH_DUMP("After LowerAllTuples: ", graph);
322   EliminateDeadCode(graph->block());
323   EnsureNoTuples(graph->block());
324 }
325 
LowerSimpleTuples(Block * block)326 void LowerSimpleTuples(Block* block) {
327   for (auto n : block->nodes()) {
328     removeTupleNodes(n, /*must_remove_tuples*/ false);
329     for (auto b : n->blocks()) {
330       LowerSimpleTuples(b);
331     }
332   }
333 }
334 
LowerSimpleTuples(const std::shared_ptr<Graph> & graph)335 void LowerSimpleTuples(const std::shared_ptr<Graph>& graph) {
336   LowerSimpleTuples(graph->block());
337   GRAPH_DUMP("After LowerSimpleTuples: ", graph);
338   EliminateDeadCode(graph);
339 }
340 } // namespace torch::jit
341