#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>

#include <ATen/InitialTensorOptions.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>

namespace torch::jit {

namespace {
const int ONNX_OPSET_13 = 13;
const int ONNX_TYPE_BOOL = 9;

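// Creates a detached onnx::Cast node that casts `val` to bool
// (attr::to = ONNX_TYPE_BOOL). The caller is responsible for inserting the
// node into the graph.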
Node* CreateCastToBoolNode(Value* val, Graph* graph) {
  Node* cast_node = graph->create(c10::onnx::Cast);
  cast_node->addInput(val);
  cast_node->i_(attr::to, ONNX_TYPE_BOOL);
  cast_node->output()->setType(BoolType::get());
  return cast_node;
}

Node* InsertCastForCond(
    Value* cond_val,
    Graph* graph,
    Node* consumer_node,
    int opset_version) {
  // prev:  cond_val -> consumer_node
  // after: cond_val -> cast -> consumer_node
  // NOTE: The cast is required because operators like PyTorch Greater/Less
  //       return tensors of type torch.uint8, whereas the condition input of
  //       an ONNX Loop must be bool.
  Node* cast_node = CreateCastToBoolNode(cond_val, graph);
  cast_node->insertBefore(consumer_node);

  consumer_node->replaceInputWith(cond_val, cast_node->output());
  const ParamMap empty_params_dict = {};
  ONNXShapeTypeInference(cast_node, empty_params_dict, opset_version);
  return cast_node;
}

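// Returns true if `cond_val` is not statically known to be a bool, i.e. it is
// neither a tensor with scalar type bool nor a value whose type is a subtype
// of Bool.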
bool IsCondCastRequired(Value* cond_val) {
  const auto& type = cond_val->type();
  if (auto tt = type->cast<TensorType>()) {
    if (auto scalar_type = tt->scalarType()) {
      return *scalar_type != c10::kBool;
    }
  }
  return !type->isSubtypeOf(*BoolType::get());
}

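// Checks whether the loop-carried sequence at sub-block input index `i` of
// `loop_node` matches the restricted pattern that can be rewritten as a scan
// output: it starts empty, is extended by a single onnx::SequenceInsert at
// the default (end) position, and is not used anywhere else in the sub-block.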
bool IsErasableSequence(const Node* loop_node, size_t i) {
  TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
  auto* sub_block = loop_node->blocks()[0];
  auto* seq_node = sub_block->outputs()[i - 1]->node();
  auto* in_val = sub_block->inputs()[i];

  if (seq_node->kind() != ::c10::onnx::SequenceInsert) {
    return false;
  }

  if (seq_node->inputs().size() == 3) {
    // Non-default insert position is not supported.
    return false;
  }

  if (seq_node->input(0) != in_val) {
    // Only SequenceInsert applied to the loop-carried sequence is supported.
    return false;
  }

  const auto* init_seq_node = loop_node->inputs()[i]->node();
  const auto init_seq_node_kind = init_seq_node->kind();
  if ((init_seq_node_kind != ::c10::onnx::SequenceEmpty) &&
      (init_seq_node_kind != ::c10::prim::ListConstruct ||
       !init_seq_node->inputs().empty())) {
    // Initial sequence must be empty.
    return false;
  }

  if (seq_node->output()->uses().size() != 1) {
    // Using the sequence elsewhere inside the sub-block is not supported.
    return false;
  }

  return true;
}

// ONNX::Loop does not support Sequence type as loop-carried dependencies. Only
// tensors are supported. This pass converts Sequence loop-carried dependencies
// to scan_outputs. In opset 11, only the below pattern is supported.
//
// PTIR graph:
//  ...
//  %res.1 : Tensor[] = prim::ListConstruct()
//  %res : Tensor[] = prim::Loop(%11, %22, %res.1)
//    block0(%i.1 : Tensor, %res.6 : Tensor[]):
//      ...
//      %res.3 : Tensor[] = aten::append(%res.6, %17)
//      -> (%22, %res.3)
//  return (%res.3)
//
// ONNX graph:
//  ...
//  %res : Tensor = onnx::Loop(%11, %22)
//    block0(%i.1 : Tensor):
//      ...
//      -> (%22, %17)
//  %res_seq : Tensor[] = onnx::SplitToSequence[keepdims=0](%res)
//  return (%res_seq)
std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::Loop) {
    return node->outputs().vec();
  }

  if (opset_version >= ONNX_OPSET_13) {
    // Sequence type as loop-carried dependencies is supported by ONNX
    // opset 13.
    return node->outputs().vec();
  }

  auto* loop_node = node;

  TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
  auto* sub_block = loop_node->blocks()[0];

  std::vector<size_t> idx_to_remove;
  std::vector<Value*> new_outputs;
  // ONNX Loop node:
  //  sub-block inputs are  (iter, cond, loop-carried dependencies)
  //  sub-block outputs are (cond, loop-carried dependencies, scan outputs)
  //  inputs are            (iter, cond, loop-carried dependencies)
  //  outputs are           (loop-carried dependencies, scan outputs)
  for (size_t i = 2; i < sub_block->inputs().size(); ++i) {
    if (IsErasableSequence(loop_node, i)) {
      auto* seq_node = sub_block->outputs()[i - 1]->node();
      // Replace sequence output with the inserted element.
      auto inserted_value = seq_node->input(1);
      sub_block->return_node()->replaceInputWith(
          seq_node->output(), inserted_value);

      // Split the added scan_output back to the expected tensor sequence.
      auto loop_output = loop_node->output(i - 2);
      Node* split_node =
          loop_node->owningGraph()->create(c10::onnx::SplitToSequence);
      loop_output->replaceAllUsesWith(split_node->output());
      split_node->i_(attr::keepdims, 0);
      split_node->addInput(loop_output);
      split_node->insertAfter(loop_node);
      split_node->output()->setType(loop_output->type());
      split_node->copyMetadata(loop_node);

      // Update loop output type.
      loop_output->setType(c10::unshapedType(inserted_value->type()));

      // The node that produces the sequence should be safe to remove now.
      seq_node->destroy();

      idx_to_remove.push_back(i);
      new_outputs.push_back(split_node->output());
    } else {
      new_outputs.push_back(loop_node->output(i - 2));
    }
  }

  // Remove sequence outputs, and replace with scan outputs.
  for (const auto i : c10::irange(idx_to_remove.size())) {
    size_t idx = idx_to_remove[i] - i;

    sub_block->eraseInput(idx);
    loop_node->removeInput(idx);

    // Swap output order. Move all scan outputs to the back.
    sub_block->return_node()->addInput(
        sub_block->return_node()->inputs().at(idx - 1));
    sub_block->return_node()->removeInput(idx - 1);

    auto loop_out = loop_node->addOutput();
    loop_out->copyMetadata(loop_node->outputs().at(idx - 2));
    loop_node->outputs().at(idx - 2)->replaceAllUsesWith(loop_out);
    loop_node->eraseOutput(idx - 2);
  }

  return new_outputs;
}

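// Creates a detached onnx::Optional node whose element type is taken from
// `opt_type`. The caller inserts the node into the graph and, for a not-None
// value, adds it as input.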
Node* ONNXOptionalNode(const OptionalTypePtr& opt_type, Graph* g) {
  TORCH_INTERNAL_ASSERT(opt_type);
  TypePtr elem_type = opt_type->getElementType();
  Node* opt_node = g->create(::c10::onnx::Optional, 1);
  opt_node->ty_(Symbol::attr("type"), elem_type);
  opt_node->output()->setType(OptionalType::create(elem_type));
  return opt_node;
}

// Replaces block output i with an onnx::Optional whose `type` is taken from
// opt_type. If and Loop ops share this function.
// 1. If op: needed when the control flow has multiple branches, one of which
//    is defined by `block` and returns None while another branch returns
//    not-None. The passed-in opt_type should come from the other branch.
// 2. Loop op: insert an Optional node before the output if the input is of
//    Optional type or the output type is None.
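// For example (hypothetical IR sketch), for a block that returns None:
//   before: block0(): ... -> (%none)
//   after:  block0(): ... %opt = onnx::Optional[type=Tensor]() -> (%opt)
// A not-None output is instead passed as input to the Optional node.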
void ReplaceBlockOutputWithOptional(
    const OptionalTypePtr& opt_type,
    Block* block,
    size_t i) {
  Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph());
  opt_node->insertBefore(block->return_node());
  Value* block_output = block->outputs().at(i);
  // Replace only the last value, as the Optional type only affects the value
  // right before the output.
  block_output->replaceAllUsesAfterNodeWith(opt_node, opt_node->output());
  if (!block_output->type()->cast<NoneType>()) {
    opt_node->addInput(block_output);
    opt_node->copyMetadata(block_output->node());
  }
}

// Resolves a limitation of ONNX: a block output cannot be a value defined
// outside the block. As a workaround, insert an Identity node inside the
// block, linking it to the outside value.
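// For example (hypothetical IR sketch), with %x defined outside the block:
//   before: block0(): ... -> (%x)
//   after:  block0(): ... %y = onnx::Identity(%x) -> (%y)
// A None output becomes an empty onnx::Optional() instead of Identity(None).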
void FixupONNXSubblockOutputs(Node* n) {
  for (Block* block : n->blocks()) {
    for (Value* output : block->outputs()) {
      if (output->node()->owningBlock() != block) {
        Node* id_node = nullptr;
        // Simplify the graph by creating an empty Optional rather than
        // Identity(None). This also enables shape inference later on, since
        // ONNX shape inference doesn't handle None.
        if (output->type()->cast<NoneType>()) {
          id_node = block->owningGraph()->create(c10::onnx::Optional);
        } else {
          id_node = block->owningGraph()->create(c10::onnx::Identity);
          id_node->addInput(output);
        }
        id_node->insertBefore(block->return_node());
        id_node->output()->copyMetadata(output);
        id_node->copyMetadata(n);
        block->return_node()->replaceInputWith(output, id_node->output());
      }
    }
  }
}

// Infer the types of optional inputs from the outputs.
void FixupONNXLoopBlockInputs(Node* n) {
  for (Block* block : n->blocks()) {
    for (const auto i : c10::irange(1, block->inputs().size())) {
      // Input i corresponds to output i until we run FixupONNXLoopNodeInputs.
      Value* input_i = block->inputs().at(i);
      if (input_i->type()->cast<OptionalType>() &&
          !block->outputs().at(i)->type()->cast<OptionalType>()) {
        auto [merged_type, inferred] = MergeInferredType(
            input_i->type()->cast<OptionalType>()->getElementType(),
            block->outputs().at(i)->type());
        if (inferred) {
          input_i->setType(OptionalType::create(merged_type));
        }
      }
    }
  }
}

// Replace None in outputs with Optional.
void FixupONNXLoopBlockOutputs(Node* n) {
  for (Block* block : n->blocks()) {
    // Output 0 is continue_condition, never None.
    for (const auto i : c10::irange(1, block->outputs().size())) {
      // Two conditions require replacing the block output with an Optional:
      // 1. the output is NoneType;
      // 2. the input is Optional but the output type is not.
      if ((block->outputs().at(i)->type()->cast<NoneType>()) ||
          (block->inputs().at(i + 1)->type()->cast<OptionalType>() &&
           !block->outputs().at(i)->type()->cast<OptionalType>())) {
        ReplaceBlockOutputWithOptional(
            // Output 0 is continue_condition.
            // Inputs (0, 1) are (loop_counter, cond), so input i + 1
            // corresponds to output i.
            block->inputs().at(i + 1)->type()->cast<OptionalType>(),
            block,
            i);
      }
    }
  }
  FixupONNXSubblockOutputs(n);
}

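// Adds the boolean casts and the implicit `cond` sub-block input that the
// ONNX Loop op expects, sets the type of the iteration counter, and wraps
// loop inputs in onnx::Optional where the corresponding sub-block input is of
// Optional type.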
void FixupONNXLoopNodeInputs(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::Loop) {
    return;
  }

  auto* graph = node->owningGraph();

  // Add a cast to the condition input outside the loop.
  Value* cond_val = node->input(1);
  if (IsCondCastRequired(cond_val)) {
    auto* cast_node = InsertCastForCond(cond_val, graph, node, opset_version);
    cast_node->copyMetadata(node);
  }

  // Set up the Loop inputs cond and i.
  TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);
  auto* sub_block = node->blocks().at(0);
  Value* cond = sub_block->insertInput(1, "cond");
  cond->setType(BoolType::get());

  Value* i = sub_block->inputs().at(0);
  i->setType(TensorType::fromNumberType(*IntType::get()));

  // Add a cast to the condition input inside the loop.
  Value* next_cond_val = sub_block->outputs().at(0);
  if (IsCondCastRequired(next_cond_val)) {
    auto* cast_node = InsertCastForCond(
        next_cond_val, graph, sub_block->return_node(), opset_version);
    cast_node->copyMetadata(node);
  }

  // Inputs (0, 1) are (max_trip_count, start_condition). Skip them
  // since they're never None or Optional.
  for (const auto i : c10::irange(2, node->inputs().size())) {
    Value* input = node->inputs().at(i);
    OptionalTypePtr sub_block_input_optional =
        sub_block->inputs().at(i)->type()->cast<OptionalType>();
    // If the loop input is not Optional but the block input is, wrap the loop
    // input with Optional. This happens when the loop takes in None and
    // outputs not-None, or vice-versa.
    if (!input->type()->cast<OptionalType>() && sub_block_input_optional) {
      if (!input->type()->cast<NoneType>()) {
        auto [merged_type, inferred] = MergeInferredType(
            sub_block_input_optional->getElementType(), input->type());
        if (inferred) {
          sub_block_input_optional = OptionalType::create(merged_type);
          sub_block->inputs().at(i)->setType(sub_block_input_optional);
        }
      }
      Node* opt_node = ONNXOptionalNode(sub_block_input_optional, graph);
      if (!input->type()->cast<NoneType>()) {
        opt_node->addInput(input);
      }
      opt_node->insertBefore(node);
      node->replaceInputWith(input, opt_node->output());
    }
  }
}
} // anonymous namespace

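// Fixes up an onnx::Loop node so that it conforms to the ONNX spec: infers
// Optional input/output types, inserts boolean casts and the implicit
// sub-block inputs, converts sequence loop-carried dependencies to scan
// outputs, and propagates block output types to the node outputs. Returns the
// node outputs corresponding to the original output order.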
std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
  auto output_size = node->outputs().size();
  GRAPH_DEBUG("before FixupONNXLoopBlockInputs: ", *node->owningGraph());
  FixupONNXLoopBlockInputs(node);
  GRAPH_DEBUG("after FixupONNXLoopBlockInputs: ", *node->owningGraph());
  FixupONNXLoopNodeInputs(node, opset_version);
  GRAPH_DEBUG("after FixupONNXLoopNodeInputs: ", *node->owningGraph());
  FixupONNXLoopBlockOutputs(node);
  GRAPH_DEBUG("after FixupONNXLoopBlockOutputs: ", *node->owningGraph());
  // NOTE: The output order is deliberately changed to match the expected
  // order, since the ONNX Loop requires scan outputs to be the last outputs.
  auto new_outputs = ConvertSequenceDependencies(node, opset_version);
  // Copy the type of each block output to the corresponding node output.
  FixupONNXControlflowNodeOutputs(node);
  GRAPH_DEBUG("after FixupONNXControlflowNodeOutputs: ", *node->owningGraph());
  TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
  return new_outputs;
}

// Check if the node is prim::Uninitialized,
// or the output of prim::Uninitialized -> onnx::Identity.
bool IsUninitializedNode(Node* n) {
  if (n->kind() == ::c10::onnx::Identity &&
      n->inputs()[0]->node()->kind() == prim::Uninitialized)
    return true;
  if (n->kind() == prim::Uninitialized)
    return true;
  return false;
}

// Infer the shape and type of the uninitialized_output from the corresponding
// output of the other subblock. The prim::Uninitialized node is proven to be
// unused, so replace it with a node of the inferred shape and type.
void InferShapeTypeForUninitializedOutput(
    Graph* graph,
    Block* block,
    Value* uninitialized_output,
    Value* other_output,
    int opset_version) {
  Node* const_node = nullptr;
  if (auto output_type = other_output->type()->cast<TensorType>()) {
    auto elem_type =
        at::initialTensorOptions().dtype(output_type->scalarType());
    const_node = graph->create(::c10::onnx::Constant, 1);

    if (output_type->sizes().concrete_sizes().has_value()) {
      auto size = output_type->sizes().concrete_sizes().value();
      const_node->t_(attr::value, at::zeros(size, elem_type));
      const_node->output()->setType(other_output->type());
    } else {
      const_node->t_(attr::value, at::zeros({}, elem_type));
      const_node->output()->setType(
          TensorType::create(*(output_type->scalarType()), at::kCPU, {}, {}));
    }
  } else if (auto output_type = other_output->type()->cast<ListType>()) {
    TypePtr elem = output_type->getElementType();
    const_node = graph->create(::c10::onnx::SequenceEmpty, 1);
    if (elem->cast<TensorType>() &&
        elem->cast<TensorType>()->scalarType().has_value()) {
      auto scalar_type = elem->cast<TensorType>()->scalarType().value();
      auto onnx_type = ATenTypeToOnnxType(scalar_type);
      const_node->i_(attr::dtype, onnx_type);
      const_node->output()->setType(other_output->type());
    } else if (elem->cast<IntType>()) {
      auto scalar_type = at::kLong;
      auto onnx_type = ATenTypeToOnnxType(scalar_type);
      const_node->i_(attr::dtype, onnx_type);
      const_node->output()->setType(other_output->type());
    } else {
      TORCH_WARN(
          "UninitializedOutput - Invalid elem Type of ListTensor found.");
      const_node->output()->setType(other_output->type());
    }
  } else if (auto output_type = other_output->type()->cast<OptionalType>()) {
    const_node = ONNXOptionalNode(output_type, graph);
  }
  TORCH_CHECK(
      const_node,
      "Inferring type for prim::Uninitialized node from " +
          other_output->type()->repr_str() + " not supported.")
  const ParamMap empty_params_dict = {};
  ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
  const_node->insertBefore(block->return_node());
  const_node->copyMetadata(block->return_node());
  uninitialized_output->replaceAllUsesWith(const_node->output());
  uninitialized_output->node()->destroy();
}

// Corresponding outputs of the ONNX If then and else subblocks must have the
// same shape and type. This pass detects whether a prim::Uninitialized node
// appears among the outputs of either subblock, and infers its shape and type
// from the corresponding output of the other subblock.
// In the example graph below, the shape and type of output %7 of block0 are
// inferred from %y.1, and the shape and type of output %7 of block1 are
// inferred from %y.5.
//
// graph(%y.1 : Int(3:4, 4:1, requires_grad=0, device=cpu)):
//   ...
//   %7 : Tensor = prim::Uninitialized()
//   %16 : bool, %17 : Tensor, %y.14 : Tensor = prim::If(%15) #
//   test/onnx/test_pytorch_onnx_onnxruntime.py:614:20
//     block0():
//       %y.5 : Tensor = aten::add(%y.1, %3, %6) #
//       test/onnx/test_pytorch_onnx_onnxruntime.py:615:28
//       -> (%2, %7, %y.5)
//     block1():
//       -> (%1, %y.1, %7)
//   ...

void ONNXFixupUninitializedOutput(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::If) {
    return;
  }

  GRAPH_DUMP("Graph before fixing If shape type: ", node->owningGraph());
  auto* if_node = node;
  auto* graph = if_node->owningGraph();

  // Check whether the input to the ONNX If node is Bool, and insert
  // a cast to Bool if needed.
  if (!if_node->input()->type()->isSubtypeOf(*BoolType::get())) {
    Node* cast_node =
        InsertCastForCond(if_node->input(), graph, if_node, opset_version);
    cast_node->copyMetadata(if_node);
  }

  Block* then_block = if_node->blocks()[0];
  Block* else_block = if_node->blocks()[1];

  // Infer shape and type for subblock outputs.
  TORCH_INTERNAL_ASSERT(
      then_block->outputs().size() == else_block->outputs().size())
  for (const auto i : c10::irange(else_block->outputs().size())) {
    Value* then_block_output = then_block->outputs()[i];
    Value* else_block_output = else_block->outputs()[i];

    // If both subblocks have an uninitialized output, shape and type cannot
    // be inferred.
    TORCH_CHECK(
        !(IsUninitializedNode(then_block_output->node()) &&
          IsUninitializedNode(else_block_output->node())),
        "Cannot infer shape and type for ONNX If with uninitialized output in both subblocks. Please check the model graph.");

    if (IsUninitializedNode(then_block_output->node())) {
      InferShapeTypeForUninitializedOutput(
          graph,
          then_block,
          then_block_output,
          else_block_output,
          opset_version);
      if_node->outputs()[i]->setType(then_block->outputs()[i]->type());
    } else if (IsUninitializedNode(else_block_output->node())) {
      InferShapeTypeForUninitializedOutput(
          graph,
          else_block,
          else_block_output,
          then_block_output,
          opset_version);
      if_node->outputs()[i]->setType(else_block->outputs()[i]->type());
    }
  }
}

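// Merges the output types of the two If subblocks onto the If node outputs:
// tensor shapes are unified dimension by dimension (mismatching dimensions
// become fresh shape symbols), and an output is promoted to Optional when
// only one branch produces an Optional or None.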
void ONNXMergeIfBlockOutputShapes(Node* node) {
  TORCH_INTERNAL_ASSERT(node->kind() == ::c10::onnx::If);
  Block* then_block = node->blocks().at(0);
  Block* else_block = node->blocks().at(1);

  TORCH_INTERNAL_ASSERT(
      then_block->outputs().size() == else_block->outputs().size())

  auto findCommonShape =
      [](const ::c10::SymbolicShape& a,
         const ::c10::SymbolicShape& b) -> ::c10::SymbolicShape {
    std::vector<::c10::ShapeSymbol> dims;
    if (a.rank() && b.rank() && a.rank() == b.rank()) {
      for (const auto j : c10::irange(a.rank().value())) {
        if (a[j] == b[j]) {
          dims.emplace_back(a[j]);
        } else {
          dims.emplace_back(::c10::ShapeSymbol::newSymbol());
        }
      }
      return ::c10::SymbolicShape(dims);
    }
    if (a.rank() && a.rank().value() > 0) {
      return a;
    }
    if (b.rank() && b.rank().value() > 0) {
      return b;
    }

    return ::c10::SymbolicShape();
  };

  auto mergeTensorType =
      [&findCommonShape](TensorTypePtr a, TensorTypePtr b) -> TensorTypePtr {
    if (a && b) {
      const auto& a_shape = a->symbolic_sizes();
      const auto& b_shape = b->symbolic_sizes();
      auto commonShape = findCommonShape(a_shape, b_shape);
      return a->withSymbolicShapes(commonShape);
    } else if (a) {
      return a;
    } else if (b) {
      return b;
    }
    return nullptr;
  };

  auto mergeListType = [&mergeTensorType](
                           ListTypePtr a, ListTypePtr b) -> ListTypePtr {
    if (a && b) {
      auto a_tensor_type = a->getElementType()->cast<TensorType>();
      auto b_tensor_type = b->getElementType()->cast<TensorType>();
      auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type);
      if (tensor_type) {
        return a->withContained({tensor_type})->cast<ListType>();
      }
      // Both branches produce ListType without tensor shape.
      return a;
    } else if (a) {
      return a;
    } else if (b) {
      return b;
    }
    return nullptr;
  };

  auto mergeOptionalType = [&mergeTensorType, &mergeListType](
                               OptionalTypePtr a,
                               OptionalTypePtr b) -> OptionalTypePtr {
    if (a && b) {
      if (a->getElementType()->cast<TensorType>()) {
        auto a_tensor_type = a->getElementType()->cast<TensorType>();
        auto b_tensor_type = b->getElementType()->cast<TensorType>();
        auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type);
        if (tensor_type) {
          return a->withContained({tensor_type})->cast<OptionalType>();
        }
        // Both branches produce OptionalType without tensor shape.
        return a;
      } else if (a->getElementType()->cast<ListType>()) {
        auto a_list_type = a->getElementType()->cast<ListType>();
        auto b_list_type = b->getElementType()->cast<ListType>();
        auto list_type = mergeListType(a_list_type, b_list_type);
        if (list_type) {
          return a->withContained({list_type})->cast<OptionalType>();
        }
        // Both branches produce OptionalType without tensor shape.
        return a;
      }
    } else if (a) {
      return a;
    } else if (b) {
      return b;
    }
    return nullptr;
  };

  for (const auto i : c10::irange(else_block->outputs().size())) {
    Value* output_i = node->output(i);
    auto then_type = then_block->outputs().at(i)->type();
    auto else_type = else_block->outputs().at(i)->type();
    auto then_tensor_type = then_type->cast<TensorType>();
    auto else_tensor_type = else_type->cast<TensorType>();
    auto then_list_type = then_type->cast<ListType>();
    auto else_list_type = else_type->cast<ListType>();
    auto then_optional_type = then_type->cast<OptionalType>();
    auto else_optional_type = else_type->cast<OptionalType>();
    auto then_none_type = then_type->cast<NoneType>();
    auto else_none_type = else_type->cast<NoneType>();
    if (then_tensor_type || else_tensor_type) {
      if (TypePtr merged_type =
              mergeTensorType(then_tensor_type, else_tensor_type)) {
        if (else_optional_type || else_none_type || then_optional_type ||
            then_none_type) {
          merged_type = OptionalType::create(merged_type);
        }
        output_i->setType(merged_type);
      }
    } else if (then_list_type || else_list_type) {
      if (TypePtr merged_type =
              mergeListType(then_list_type, else_list_type)) {
        if (else_optional_type || else_none_type || then_optional_type ||
            then_none_type) {
          merged_type = OptionalType::create(merged_type);
        }
        output_i->setType(merged_type);
      }
    }

    if (then_optional_type || else_optional_type) {
      if (auto optional_type =
              mergeOptionalType(then_optional_type, else_optional_type)) {
        output_i->setType(optional_type);
        // Both branches' output types must match.
        if (!then_optional_type) {
          ReplaceBlockOutputWithOptional(optional_type, then_block, i);
        } else if (!else_optional_type) {
          ReplaceBlockOutputWithOptional(optional_type, else_block, i);
        }
      }
    }

    if (then_none_type && !else_optional_type) {
      ReplaceBlockOutputWithOptional(
          output_i->type()->cast<OptionalType>(), then_block, i);
    }

    if (else_none_type && !then_optional_type) {
      ReplaceBlockOutputWithOptional(
          output_i->type()->cast<OptionalType>(), else_block, i);
    }
  }
}

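// Fixes up an onnx::If node: routes outside values through Identity nodes in
// the subblocks, infers shape and type for prim::Uninitialized outputs from
// the other branch, and merges the branch output types onto the node outputs.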
std::vector<Value*> FixupONNXIfNode(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::If) {
    return node->outputs().vec();
  }
  GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph());
  FixupONNXSubblockOutputs(node);
  ONNXFixupUninitializedOutput(node, opset_version);
  ONNXMergeIfBlockOutputShapes(node);

  GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
  return node->outputs().vec();
}

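// Entry point of this pass: dispatches to the Loop or If fixup above and
// returns the (possibly reordered) node outputs.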
std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version) {
  switch (n->kind()) {
    case ::c10::onnx::Loop: {
      return FixupONNXLoopNode(n, opset_version);
    }
    case ::c10::onnx::If: {
      return FixupONNXIfNode(n, opset_version);
    }
    default:
      return n->outputs().vec();
  }
}

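// Propagates block output types to the control-flow node outputs. For Loop,
// loop-carried outputs take the block output type (promoted to Optional when
// the block input is Optional but the output is not), and scan outputs gain a
// leading dynamic dimension, e.g. a per-iteration scan output of shape (3, 4)
// becomes (*, 3, 4) on the Loop node.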
void FixupONNXControlflowNodeOutputs(Node* n) {
  switch (n->kind()) {
    case ::c10::onnx::Loop: {
      Block* loop_block = n->blocks().at(0);
      // Inputs (0, 1) are (i, cond); the remainder are carried outputs.
      size_t loop_carried_output_size = loop_block->inputs().size() - 2;

      for (auto i : c10::irange(n->outputs().size())) {
        if (i < loop_carried_output_size) {
          const TypePtr block_input_type =
              loop_block->inputs().at(i + 2)->type();
          const TypePtr block_output_type =
              loop_block->outputs().at(i + 1)->type();
          TypePtr type = block_output_type;
          // Handle the case where a block input is Optional but the
          // output is not (i.e. if the loop executes > 0 times, the
          // output will not be None).
          if (block_input_type->cast<OptionalType>() &&
              !block_output_type->cast<OptionalType>()) {
            type = OptionalType::create(block_output_type);
          }
          n->output(i)->setType(type);
        } else {
          // Scan output; should be a Tensor type.
          TypePtr type = loop_block->outputs().at(i + 1)->type();
          if (auto t_type = type->cast<TensorType>()) {
            auto sizes = t_type->symbolic_sizes().sizes();
            if (sizes.has_value()) {
              sizes.value().emplace(
                  sizes.value().begin(), c10::ShapeSymbol::newSymbol());
              type = t_type->withSymbolicShapes(sizes.value());
            }
          }
          n->output(i)->setType(type);
        }
      }
      break;
    }
    case ::c10::onnx::If: {
      ONNXMergeIfBlockOutputShapes(n);
      break;
    }
    default:
      break;
  }
}

} // namespace torch::jit