1 #include <torch/csrc/jit/passes/lower_tuples.h>
2
3 #include <ATen/core/functional.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/ir/constants.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9
10 #include <utility>
11
12 namespace torch::jit {
13
14 namespace {
15
16 // operators where we expect to find tuples as inputs/outputs
17 // this is to assert we are only doing modifications when we know
18 // we can flatten tuples
19 std::unordered_set<Symbol> supported_ops = {
20 prim::If,
21 prim::Loop,
22 prim::Uninitialized,
23 prim::TupleUnpack,
24 prim::TupleConstruct,
25 prim::TupleIndex,
26 prim::TupleSlice,
27 prim::Param,
28 prim::Return,
29 prim::PythonOp,
30 aten::format,
31 prim::Uninitialized,
32 aten::__getitem__};
33
// Flatten block inputs and insert a tuple construct in the block
//
// For a node `n` (used for prim::Loop) whose input at `index` is a tuple,
// this replaces the single tuple-typed loop-carried value with its flattened
// elements on both the node's input list and the loop body's parameter list,
// then prepends a TupleConstruct inside the body so existing uses of the
// tuple parameter keep seeing a tuple value.
static void flattenTupleInLoopParams(Node* n, size_t index) {
  auto input = n->inputs().at(index);
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  Block* block = n->blocks().at(0);
  Node* block_node = n;

  // NOTE(review): unused; presumably left over from an earlier revision.
  std::vector<Value*> new_node_inputs = {};
  auto new_construct_node =
      block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    // Add one flattened element to the block's parameters and mirror it on
    // the node's inputs (taken from the TupleConstruct feeding `input`).
    auto new_block_in = block->insertInput(index + j);
    new_construct_node->addInput(new_block_in);
    block_node->insertInput(index + j + 1, input->node()->inputs().at(j));
  }
  // Node input `index` corresponds to block input `index - 1` — the Loop
  // node carries extra leading inputs (trip count / condition) that the
  // block represents with a single iteration-counter parameter.
  new_construct_node->output()->setType(block->inputs().at(index - 1)->type());
  new_construct_node->copyMetadata(n);
  // Redirect all body uses of the old tuple parameter to the rebuilt tuple,
  // then drop the now-unused tuple parameter and node input.
  block->inputs().at(index - 1)->replaceAllUsesWith(
      new_construct_node->output());
  block->eraseInput(index - 1);
  block_node->removeInput(index);
}
58
// Flatten tuple outputs of the block node and append a TupleConstruct
// node after the block node if there is an outer block.
//
// `n` is a return node whose input at `index` is a tuple produced by a
// TupleConstruct. The tuple is flattened on the enclosing block's output
// list; if the block belongs to a node (If/Loop), that node's outputs are
// flattened too and a TupleConstruct is inserted after it so outer users
// still observe a tuple value.
static void flattenTupleInBlockReturn(Node* n, size_t index) {
  auto input = n->inputs().at(index);
  Block* block = n->owningBlock();
  Node* block_node = block->owningNode();
  Node* new_construct_node = nullptr;
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  // 1- Add flattened tuple to block outputs
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    block->insertOutput(index + j + 1, input->node()->inputs().at(j));
  }
  block->eraseOutput(index);

  // Graph-level return: there is no outer node whose outputs need fixing.
  if (block_node == nullptr)
    return;
  // 2- For uses of the block node in the outer block,
  // flatten the blocknode outputs and insert a tuple construct
  // to replace that.
  // Loop block has an extra element (iter counter)
  if (block_node->kind() == prim::Loop)
    index = index - 1;
  auto tuple_output = block_node->outputs().at(index);
  // When node has multiple blocks, do not flatten outputs on the second block
  // again
  if (!(tuple_output->type()->cast<TupleType>()))
    return;

  new_construct_node = block->owningGraph()->create(prim::TupleConstruct);
  new_construct_node->insertAfter(block_node);
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    // New node outputs are left untyped here; their types are established
    // by the block outputs they forward.
    auto new_block_out = block_node->insertOutput(index + j + 1);
    new_construct_node->addInput(new_block_out);
  }
  // Replace the block node with the new TupleConstruct node
  new_construct_node->output()->setType(tuple_output->type());
  new_construct_node->copyMetadata(block_node);
  tuple_output->replaceAllUsesWith(new_construct_node->output());
  block_node->eraseOutput(index);
}
101
// Rewrite a tuple-access node (TupleUnpack / TupleIndex / TupleSlice) whose
// tuple operand comes directly from a TupleConstruct, redirecting its output
// uses to the construct's inputs so the access (and eventually the construct)
// becomes dead. When `must_remove_tuples` is true, an access that cannot be
// matched to a TupleConstruct (or whose index is non-constant) is an error;
// otherwise the node is silently left alone.
void removeTupleNodes(Node* n, bool must_remove_tuples) {
  if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
      n->kind() != prim::TupleSlice) {
    return;
  }
  // tuple index has two inputs, tuple and index
  auto construct_node = n->inputs().at(0)->node();
  if (construct_node->kind() != prim::TupleConstruct) {
    if (must_remove_tuples) {
      AT_ERROR(n->kind().toQualString(), " not matched to tuple construct");
    }
    return;
  }
  if (n->kind() == prim::TupleUnpack) {
    // Each unpack output maps 1:1 onto a construct input.
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i));
    }
  } else if (n->kind() == prim::TupleIndex) {
    auto idx = n->inputs().at(1);
    auto maybe_int = constant_as<int64_t>(idx);
    if (!maybe_int) {
      if (must_remove_tuples) {
        AT_ERROR(n->sourceRange(), "tuple index with non-constant index");
      }
      return;
    }
    auto int_idx = *maybe_int;
    size_t len = construct_node->output()->type()->containedTypes().size();
    // Python-style negative indexing counts from the end.
    if (int_idx < 0) {
      int_idx += len;
    }
    // currently, we allow non-constant tuple index if the tuple is of one type.
    // so we need to check bounds here
    if (int_idx >= 0 && static_cast<size_t>(int_idx) < len) {
      n->output()->replaceAllUsesWith(construct_node->inputs().at(int_idx));
    }
  } else if (n->kind() == prim::TupleSlice) {
    // Build a fresh, smaller tuple out of the sliced construct inputs.
    std::vector<Value*> values;
    int64_t beg = n->i(attr::beg);
    int64_t end = n->i(attr::end);
    for (int64_t i = beg; i < end; i += 1) {
      values.push_back(construct_node->inputs().at(i));
    }
    auto graph = n->owningGraph();
    auto tuple_out = graph->createTuple(values);
    tuple_out->copyMetadata(n);
    WithInsertPoint insert(n);
    graph->insertNode(tuple_out);
    n->output()->replaceAllUsesWith(tuple_out->output());
  }
}
153 } // anonymous namespace
154
155 static void LowerAllTuples(Block* block);
156
// Expand a prim::Constant that holds a tuple IValue into one constant per
// element plus a TupleConstruct, recursing into elements that are themselves
// tuple constants. No-op for any other node.
static void RemoveTupleConstants(Node* n) {
  if (!(n->kind() == prim::Constant &&
        n->output()->type()->cast<TupleType>())) {
    return;
  }

  auto g = n->owningGraph();
  auto tuple = toIValue(n->output()).value().toTuple();
  const auto& tuple_elements = tuple->elements();
  // Insert the element constants where the original constant lived.
  WithInsertPoint insert(n);
  std::vector<Value*> elements;
  for (const auto& elem : tuple_elements) {
    auto constant = insertConstant(*n->owningGraph(), elem);
    elements.push_back(constant);
  }
  auto tuple_type = n->output()->type()->expect<TupleType>();
  // Preserve a named-tuple type (one with a schema); otherwise let
  // createTuple infer a plain tuple type from the elements.
  auto tuple_construct = g->insertNode(n->owningGraph()->createTuple(
      elements, tuple_type->schema() ? std::move(tuple_type) : nullptr));
  tuple_construct->copyMetadata(n);

  // insert the tuple first before recursing on its elements, so that its
  // elements will have a use
  for (Value* elem : elements) {
    RemoveTupleConstants(elem->node());
  }

  n->replaceAllUsesWith(tuple_construct);
}
185
// Replace every tuple-typed input of `n` with the flattened elements of the
// TupleConstruct that produced it: op(a, tup, b) --> op(a, t0, t1, b).
// Loop and Return nodes delegate to the block-aware helpers above.
// NOTE(review): `insert_point` is unused here; it exists so this function's
// signature mirrors flattenOutputs, which does need it.
static void flattenInputs(Node* n, Node* insert_point) {
  // flatten the input list op(a, tup, b) --> op(a, t0, t1, b)
  for (size_t i = 0; i < n->inputs().size();) {
    auto input = n->inputs()[i];
    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
      TORCH_CHECK(
          (input->node()->kind() == prim::TupleConstruct),
          "tuple use not matched to tuple construct. Instead found: ",
          n->kind().toQualString());
      if (supported_ops.count(n->kind()) > 0) {
        if (n->kind() == prim::Loop) {
          // This function supports all node types with blocks that take tuple
          // inputs.
          flattenTupleInLoopParams(n, i);
        } else if (n->kind() == prim::Return) {
          flattenTupleInBlockReturn(n, i);
        } else {
          // Splice the construct's inputs in after position i, then drop the
          // tuple input itself.
          for (size_t j = 0; j < tt->elements().size(); ++j) {
            n->insertInput(i + 1 + j, input->node()->inputs().at(j));
          }
          n->removeInput(i);
        }
        // note: no update to i
        // since tuples might be nested we need to recursively scan
        // the new flattened inputs
      } else {
        TORCH_WARN(
            "tuple appears in op inputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
223
// Replace every used tuple-typed output of `n` with flattened element
// outputs, inserting a TupleConstruct before `insert_point` so existing
// users of the tuple output keep seeing a tuple value.
static void flattenOutputs(Node* n, Node* insert_point) {
  // flatten the outputs list
  auto& graph = *n->owningGraph();
  for (size_t i = 0; i < n->outputs().size();) {
    Value* output = n->outputs()[i];
    // Unused tuple outputs need no construct; dead-code elimination will
    // clean them up later.
    if (!output->hasUses()) {
      ++i;
      continue;
    }

    // (a, b, tup, c) -> (a, b, t0, t1, c)
    // and:
    // tup = (t0, t1)
    // is placed at the current insertion point
    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
      if (supported_ops.count(n->kind()) > 0) {
        for (const auto j : c10::irange(tt->elements().size())) {
          n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
        }
        auto new_tup =
            graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
        new_tup->copyMetadata(n);
        new_tup->insertBefore(insert_point);
        // Later constructs must come after this one, since nested tuples
        // may reference it.
        insert_point = new_tup;
        output->replaceAllUsesWith(new_tup->output());
        n->eraseOutput(i);
        // note: no update to i to handle nested tuples
      } else {
        TORCH_WARN(
            "tuple appears in the op outputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
263
VisitNode(Node * n,Node * insert_point)264 static void VisitNode(Node* n, Node* insert_point) {
265 // tuple construction operators will become dead when the unpacks are replaced
266 if (n->kind() == prim::TupleConstruct) {
267 return;
268 }
269 // note: changing the second argument to false changes this pass from a
270 // complete lowering pass to one that removes tuples when possible. When
271 // tuples are first-class in the interpreter, we should still run this pass to
272 // remove extraneous uses
273 if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
274 n->kind() == prim::TupleSlice) {
275 removeTupleNodes(n, /*must_remove_tuples*/ true);
276 return;
277 }
278 flattenInputs(n, insert_point);
279 for (auto b : n->blocks()) {
280 LowerAllTuples(b);
281 }
282 flattenOutputs(n, insert_point);
283 }
284
// Lower all tuples appearing in `block`: its parameters, every node in it
// (recursively), and its return values.
static void LowerAllTuples(Block* block) {
  // tuples in parameter lists of a block behave exactly the same as
  // _outputs_ of normal instructions, since the param_node represents the
  // parameters as outputs, we can handle it by simply visiting the node
  VisitNode(block->param_node(), *block->nodes().begin());
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    // Advance before visiting: `*it` (the next node) doubles as the insert
    // point for TupleConstructs created while flattening n's outputs, and
    // advancing first keeps the iterator valid if n's rewrite inserts nodes.
    auto n = *it++;
    RemoveTupleConstants(n);
    VisitNode(n, *it);
  }
  // tuples in return lists of blocks behave exactly the same as
  // _inputs_ of normal instructions, so we can use VisitNode here as well
  // insert_point is null because it will never be used since return nodes
  // have no outputs
  VisitNode(block->return_node(), nullptr);
}
302
EnsureNoTuples(ArrayRef<Value * > values)303 static void EnsureNoTuples(ArrayRef<Value*> values) {
304 for (Value* v : values) {
305 TORCH_CHECK(
306 v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
307 }
308 }
309
EnsureNoTuples(Block * block)310 static void EnsureNoTuples(Block* block) {
311 for (Node* n : block->nodes()) {
312 for (Block* b : n->blocks()) {
313 EnsureNoTuples(b);
314 }
315 EnsureNoTuples(n->outputs());
316 }
317 }
318
// Public entry point: fully lower all tuples in `graph`, clean up the dead
// TupleConstruct/Unpack nodes, and assert that none remain.
void LowerAllTuples(const std::shared_ptr<Graph>& graph) {
  LowerAllTuples(graph->block());
  GRAPH_DUMP("After LowerAllTuples: ", graph);
  // Lowering leaves the now-unused tuple nodes behind; DCE removes them
  // before the final no-tuples assertion.
  EliminateDeadCode(graph->block());
  EnsureNoTuples(graph->block());
}
325
LowerSimpleTuples(Block * block)326 void LowerSimpleTuples(Block* block) {
327 for (auto n : block->nodes()) {
328 removeTupleNodes(n, /*must_remove_tuples*/ false);
329 for (auto b : n->blocks()) {
330 LowerSimpleTuples(b);
331 }
332 }
333 }
334
// Public entry point: opportunistically remove tuple accesses in `graph`
// (tuples that cannot be removed are left in place), then clean up the
// dead construct nodes.
void LowerSimpleTuples(const std::shared_ptr<Graph>& graph) {
  LowerSimpleTuples(graph->block());
  GRAPH_DUMP("After LowerSimpleTuples: ", graph);
  EliminateDeadCode(graph);
}
340 } // namespace torch::jit
341