#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>

#include <ATen/InitialTensorOptions.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>

namespace torch::jit {

namespace {
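// 9 is TensorProto_DataType_BOOL in onnx.proto; opset 13 is the first ONNX
// opset where Loop accepts Sequence values as loop-carried dependencies
// (see ConvertSequenceDependencies below).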
const int ONNX_OPSET_13 = 13;
const int ONNX_TYPE_BOOL = 9;

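// Example of the node produced: %1 : Bool = onnx::Cast[to=9](%val)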
Node* CreateCastToBoolNode(Value* val, Graph* graph) {
  Node* cast_node = graph->create(c10::onnx::Cast);
  cast_node->addInput(val);
  cast_node->i_(attr::to, ONNX_TYPE_BOOL);
  cast_node->output()->setType(BoolType::get());
  return cast_node;
}

Node* InsertCastForCond(
    Value* cond_val,
    Graph* graph,
    Node* consumer_node,
    int opset_version) {
  // prev:  cond_val -> consumer_node
  // after: cond_val -> cast -> consumer_node
  // NOTE: The cast is required because operators like PyTorch Greater/Less
  //       return tensors of type torch.uint8, while the condition input of
  //       an ONNX Loop must be bool.
  Node* cast_node = CreateCastToBoolNode(cond_val, graph);
  cast_node->insertBefore(consumer_node);

  consumer_node->replaceInputWith(cond_val, cast_node->output());
  const ParamMap empty_params_dict = {};
  ONNXShapeTypeInference(cast_node, empty_params_dict, opset_version);
  return cast_node;
}

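// Returns true unless cond_val is statically known to be a bool tensor or a
// bool scalar; only then can the onnx::Cast be skipped.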
bool IsCondCastRequired(Value* cond_val) {
  const auto& type = cond_val->type();
  if (auto tt = type->cast<TensorType>()) {
    if (auto scalar_type = tt->scalarType()) {
      return *scalar_type != c10::kBool;
    }
  }
  return !type->isSubtypeOf(*BoolType::get());
}

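// Checks whether the loop-carried sequence at sub-block input index i
// matches the pattern ConvertSequenceDependencies can erase: the initial
// sequence is empty, each iteration appends a single element via
// onnx::SequenceInsert at the default (end) position, and the resulting
// sequence has no other uses inside the sub-block.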
bool IsErasableSequence(const Node* loop_node, size_t i) {
  TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
  auto* sub_block = loop_node->blocks()[0];
  auto* seq_node = sub_block->outputs()[i - 1]->node();
  auto* in_val = sub_block->inputs()[i];

  if (seq_node->kind() != ::c10::onnx::SequenceInsert) {
    return false;
  }

  if (seq_node->inputs().size() == 3) {
    // Non-default insert position is not supported.
    return false;
  }

  if (seq_node->input(0) != in_val) {
    // Only SequenceInsert applied to the loop-carried sequence is supported.
    return false;
  }

  const auto* init_seq_node = loop_node->inputs()[i]->node();
  const auto init_seq_node_kind = init_seq_node->kind();
  if ((init_seq_node_kind != ::c10::onnx::SequenceEmpty) &&
      (init_seq_node_kind != ::c10::prim::ListConstruct ||
       !init_seq_node->inputs().empty())) {
    // Initial sequence must be empty.
    return false;
  }

  if (seq_node->output()->uses().size() != 1) {
    // Sequences used elsewhere inside the sub-block are not supported.
    return false;
  }

  return true;
}

// ONNX::Loop does not support Sequence type as loop-carried dependencies. Only
// tensors are supported. This pass converts Sequence loop-carried dependencies
// to scan_outputs. In opset 11, only the below pattern is supported.
//
// PTIR graph:
//  ...
//  %res.1 : Tensor[] = prim::ListConstruct()
//  %res : Tensor[] = prim::Loop(%11, %22, %res.1)
//    block0(%i.1 : Tensor, %res.6 : Tensor[]):
//      ...
//      %res.3 : Tensor[] = aten::append(%res.6, %17)
//      -> (%22, %res.3)
//  return (%res.3)
//
// ONNX graph:
//  ...
//  %res : Tensor = onnx::Loop(%11, %22)
//    block0(%i.1 : Tensor):
//      ...
//      -> (%22, %17)
//  %res_seq : Tensor[] = onnx::SplitToSequence[keepdims=0](%res)
//  return (%res_seq)
std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::Loop) {
    return node->outputs().vec();
  }

  if (opset_version >= ONNX_OPSET_13) {
    // Sequence type loop-carried dependencies are supported natively from
    // ONNX opset 13 onwards, so no conversion is needed.
    return node->outputs().vec();
  }

  auto* loop_node = node;

  TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
  auto* sub_block = loop_node->blocks()[0];

  std::vector<size_t> idx_to_remove;
  std::vector<Value*> new_outputs;
  // ONNX Loop node:
  // sub-block inputs are  (iter, cond, loop-carried dependencies)
  // sub-block outputs are (      cond, loop-carried dependencies, scan outputs)
  // inputs are            (iter, cond, loop-carried dependencies)
  // outputs are           (            loop-carried dependencies, scan outputs)
  for (size_t i = 2; i < sub_block->inputs().size(); ++i) {
    if (IsErasableSequence(loop_node, i)) {
      auto* seq_node = sub_block->outputs()[i - 1]->node();
      // Replace sequence output with the inserted element.
      auto inserted_value = seq_node->input(1);
      sub_block->return_node()->replaceInputWith(
          seq_node->output(), inserted_value);

      // Split the added scan_output back to expected tensor sequence.
      auto loop_output = loop_node->output(i - 2);
      Node* split_node =
          loop_node->owningGraph()->create(c10::onnx::SplitToSequence);
      loop_output->replaceAllUsesWith(split_node->output());
      split_node->i_(attr::keepdims, 0);
      split_node->addInput(loop_output);
      split_node->insertAfter(loop_node);
      split_node->output()->setType(loop_output->type());
      split_node->copyMetadata(loop_node);

      // Update loop output type.
      loop_output->setType(c10::unshapedType(inserted_value->type()));

      // The node that produces the sequence should be safe to remove now.
      seq_node->destroy();

      idx_to_remove.push_back(i);
      new_outputs.push_back(split_node->output());
    } else {
      new_outputs.push_back(loop_node->output(i - 2));
    }
  }

  // Remove sequence outputs, and replace with scan outputs.
  for (const auto i : c10::irange(idx_to_remove.size())) {
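    // Each earlier erasure shifts the remaining recorded indices down by
    // one, so subtract the number of entries already removed.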
    size_t idx = idx_to_remove[i] - i;

    sub_block->eraseInput(idx);
    loop_node->removeInput(idx);

    // Swap output order. Move all scan outputs to the back.
    sub_block->return_node()->addInput(
        sub_block->return_node()->inputs().at(idx - 1));
    sub_block->return_node()->removeInput(idx - 1);

    auto loop_out = loop_node->addOutput();
    loop_out->copyMetadata(loop_node->outputs().at(idx - 2));
    loop_node->outputs().at(idx - 2)->replaceAllUsesWith(loop_out);
    loop_node->eraseOutput(idx - 2);
  }

  return new_outputs;
}

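// Creates an onnx::Optional node whose "type" attribute and output type
// carry the element type of opt_type. The caller decides whether to add a
// payload input (for non-None values) and where to insert the node.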
Node* ONNXOptionalNode(const OptionalTypePtr& opt_type, Graph* g) {
  TORCH_INTERNAL_ASSERT(opt_type);
  TypePtr elem_type = opt_type->getElementType();
  Node* opt_node = g->create(::c10::onnx::Optional, 1);
  opt_node->ty_(Symbol::attr("type"), elem_type);
  opt_node->output()->setType(OptionalType::create(elem_type));
  return opt_node;
}

// Replaces block output i with an onnx::Optional whose `type` attribute is
// taken from opt_type. Both the If and Loop ops share this function.
// 1. If op: needed when control flow has multiple branches, one of which is
//    defined by `block` and returns None while another branch returns a
//    not-None value. The passed-in opt_type should come from the other
//    branch.
// 2. Loop op: insert the Optional node before the output if the input is of
//    Optional type or the output type is None.
void ReplaceBlockOutputWithOptional(
    const OptionalTypePtr& opt_type,
    Block* block,
    size_t i) {
  Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph());
  opt_node->insertBefore(block->return_node());
  Value* block_output = block->outputs().at(i);
  // Replace only the uses after opt_node, since wrapping in Optional should
  // only affect the value flowing into the block output.
  block_output->replaceAllUsesAfterNodeWith(opt_node, opt_node->output());
  if (!block_output->type()->cast<NoneType>()) {
    opt_node->addInput(block_output);
    opt_node->copyMetadata(block_output->node());
  }
}

// Resolves an ONNX limitation: a block output cannot be a value coming from
// outside the block. As a workaround, insert an Identity node inside the
// block that links to the outside value.
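// For example, a block returning an outer value "-> (%outer)" becomes
// "%id = onnx::Identity(%outer); -> (%id)"; a None output instead gets an
// input-less onnx::Optional so that ONNX shape inference can handle it.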
void FixupONNXSubblockOutputs(Node* n) {
  for (Block* block : n->blocks()) {
    for (Value* output : block->outputs()) {
      if (output->node()->owningBlock() != block) {
        Node* id_node = nullptr;
        // Simplify graph by creating an empty optional rather than
        // Identity(None). Also enables shape inference later on, since
        // ONNX shape inference doesn't handle None.
        if (output->type()->cast<NoneType>()) {
          id_node = block->owningGraph()->create(c10::onnx::Optional);
        } else {
          id_node = block->owningGraph()->create(c10::onnx::Identity);
          id_node->addInput(output);
        }
        id_node->insertBefore(block->return_node());
        id_node->output()->copyMetadata(output);
        id_node->copyMetadata(n);
        block->return_node()->replaceInputWith(output, id_node->output());
      }
    }
  }
}

// Infer the type of Optional block inputs from the corresponding outputs.
void FixupONNXLoopBlockInputs(Node* n) {
  for (Block* block : n->blocks()) {
    for (const auto i : c10::irange(1, block->inputs().size())) {
      // Input i corresponds to output i until we run FixupONNXLoopNodeInputs.
      Value* input_i = block->inputs().at(i);
      if (input_i->type()->cast<OptionalType>() &&
          !block->outputs().at(i)->type()->cast<OptionalType>()) {
        auto [merged_type, inferred] = MergeInferredType(
            input_i->type()->cast<OptionalType>()->getElementType(),
            block->outputs().at(i)->type());
        if (inferred) {
          input_i->setType(OptionalType::create(merged_type));
        }
      }
    }
  }
}

// Replace None in outputs with Optional.
void FixupONNXLoopBlockOutputs(Node* n) {
  for (Block* block : n->blocks()) {
    // Output 0 is continue_condition, which is never None.
    for (const auto i : c10::irange(1, block->outputs().size())) {
      // Two conditions require replacing the block output with an Optional:
      // 1. the output is NoneType;
      // 2. the input is Optional but the output type is not.
      if ((block->outputs().at(i)->type()->cast<NoneType>()) ||
          (block->inputs().at(i + 1)->type()->cast<OptionalType>() &&
           !block->outputs().at(i)->type()->cast<OptionalType>())) {
        ReplaceBlockOutputWithOptional(
            // Output 0 is continue_condition.
            // Inputs (0, 1) are (loop_counter, cond). So input i + 1
            // corresponds to output i.
            block->inputs().at(i + 1)->type()->cast<OptionalType>(),
            block,
            i);
      }
    }
  }
  FixupONNXSubblockOutputs(n);
}

void FixupONNXLoopNodeInputs(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::Loop) {
    return;
  }

  auto* graph = node->owningGraph();

  // Add a cast to the condition input outside the loop.
  Value* cond_val = node->input(1);
  if (IsCondCastRequired(cond_val)) {
    auto* cast_node = InsertCastForCond(cond_val, graph, node, opset_version);
    cast_node->copyMetadata(node);
  }

  // Set up the Loop sub-block inputs `cond` and `i`. At this point the
  // sub-block's inputs are (iter, loop-carried dependencies); ONNX Loop
  // requires (iter, cond, loop-carried dependencies), so insert `cond` at
  // index 1.
  TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);
  auto* sub_block = node->blocks().at(0);
  Value* cond = sub_block->insertInput(1, "cond");
  cond->setType(BoolType::get());

  Value* i = sub_block->inputs().at(0);
  i->setType(TensorType::fromNumberType(*IntType::get()));

  // Add a cast to the condition input inside the loop.
  Value* next_cond_val = sub_block->outputs().at(0);
  if (IsCondCastRequired(next_cond_val)) {
    auto* cast_node = InsertCastForCond(
        next_cond_val, graph, sub_block->return_node(), opset_version);
    cast_node->copyMetadata(node);
  }

  // Inputs (0, 1) are (max_trip_count, start_condition). Skip them
  // since they're never None or Optional.
  for (const auto i : c10::irange(2, node->inputs().size())) {
    Value* input = node->inputs().at(i);
    OptionalTypePtr sub_block_input_optional =
        sub_block->inputs().at(i)->type()->cast<OptionalType>();
    // If the loop input is not Optional but the block input is, wrap the
    // loop input in an Optional. This happens when the loop takes in None
    // and outputs not-None, or vice-versa.
    if (!input->type()->cast<OptionalType>() && sub_block_input_optional) {
      if (!input->type()->cast<NoneType>()) {
        auto [merged_type, inferred] = MergeInferredType(
            sub_block_input_optional->getElementType(), input->type());
        if (inferred) {
          sub_block_input_optional = OptionalType::create(merged_type);
          sub_block->inputs().at(i)->setType(sub_block_input_optional);
        }
      }
      Node* opt_node = ONNXOptionalNode(sub_block_input_optional, graph);
      if (!input->type()->cast<NoneType>()) {
        opt_node->addInput(input);
      }
      opt_node->insertBefore(node);
      node->replaceInputWith(input, opt_node->output());
    }
  }
}
} // anonymous namespace

std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
  auto output_size = node->outputs().size();
  GRAPH_DEBUG("before FixupONNXLoopBlockInputs: ", *node->owningGraph());
  FixupONNXLoopBlockInputs(node);
  GRAPH_DEBUG("after FixupONNXLoopBlockInputs: ", *node->owningGraph());
  FixupONNXLoopNodeInputs(node, opset_version);
  GRAPH_DEBUG("after FixupONNXLoopNodeInputs: ", *node->owningGraph());
  FixupONNXLoopBlockOutputs(node);
  GRAPH_DEBUG("after FixupONNXLoopBlockOutputs: ", *node->owningGraph());
  // NOTE: The output order is deliberately changed to match the expected
  //       order, since ONNX Loop requires scan outputs to be the last
  //       outputs.
  auto new_outputs = ConvertSequenceDependencies(node, opset_version);
  // Copy the type of each block output to the corresponding node output.
  FixupONNXControlflowNodeOutputs(node);
  GRAPH_DEBUG("after FixupONNXControlflowNodeOutputs: ", *node->owningGraph());
  TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
  return new_outputs;
}

// Check if the node is prim::Uninitialized, or the output of a
// prim::Uninitialized -> onnx::Identity chain.
bool IsUninitializedNode(Node* n) {
  if (n->kind() == ::c10::onnx::Identity &&
      n->inputs()[0]->node()->kind() == prim::Uninitialized)
    return true;
  if (n->kind() == prim::Uninitialized)
    return true;
  return false;
}

// Infer the shape and type of uninitialized_output from the corresponding
// output of the other subblock. The prim::Uninitialized node is proven to
// be unused, so replace it with a node of the inferred shape and type.
void InferShapeTypeForUninitializedOutput(
    Graph* graph,
    Block* block,
    Value* uninitialized_output,
    Value* other_output,
    int opset_version) {
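  // The replacement node is only a placeholder carrying the inferred
  // dtype/shape; since prim::Uninitialized is proven unused, its value is
  // never actually observed at runtime.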
  Node* const_node = nullptr;
  if (auto output_type = other_output->type()->cast<TensorType>()) {
    auto elem_type =
        at::initialTensorOptions().dtype(output_type->scalarType());
    const_node = graph->create(::c10::onnx::Constant, 1);

    if (output_type->sizes().concrete_sizes().has_value()) {
      auto size = output_type->sizes().concrete_sizes().value();
      const_node->t_(attr::value, at::zeros(size, elem_type));
      const_node->output()->setType(other_output->type());
    } else {
      const_node->t_(attr::value, at::zeros({}, elem_type));
      const_node->output()->setType(
          TensorType::create(*(output_type->scalarType()), at::kCPU, {}, {}));
    }
  } else if (auto output_type = other_output->type()->cast<ListType>()) {
    TypePtr elem = output_type->getElementType();
    const_node = graph->create(::c10::onnx::SequenceEmpty, 1);
    if (elem->cast<TensorType>() &&
        elem->cast<TensorType>()->scalarType().has_value()) {
      auto scalar_type = elem->cast<TensorType>()->scalarType().value();
      auto onnx_type = ATenTypeToOnnxType(scalar_type);
      const_node->i_(attr::dtype, onnx_type);
      const_node->output()->setType(other_output->type());
    } else if (elem->cast<IntType>()) {
      auto scalar_type = at::kLong;
      auto onnx_type = ATenTypeToOnnxType(scalar_type);
      const_node->i_(attr::dtype, onnx_type);
      const_node->output()->setType(other_output->type());
    } else {
      TORCH_WARN(
          "UninitializedOutput - Invalid elem Type of ListTensor found.");
      const_node->output()->setType(other_output->type());
    }
  } else if (auto output_type = other_output->type()->cast<OptionalType>()) {
    const_node = ONNXOptionalNode(output_type, graph);
  }
  TORCH_CHECK(
      const_node,
      "Inferring type for prim::Uninitialized node from " +
          other_output->type()->repr_str() + " not supported.")
  const ParamMap empty_params_dict = {};
  ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
  const_node->insertBefore(block->return_node());
  const_node->copyMetadata(block->return_node());
  uninitialized_output->replaceAllUsesWith(const_node->output());
  uninitialized_output->node()->destroy();
}

// Corresponding outputs of the ONNX If then- and else-subblocks should have
// the same shape and type. This pass detects whether a prim::Uninitialized
// node appears among the outputs of either subblock, and infers its shape
// and type from the corresponding output of the other subblock.
// In the example graph below, the shape and type of output %7 in block0 is
// inferred from %y.1 in block1, and the shape and type of output %7 in
// block1 is inferred from %y.5 in block0.
//
// graph(%y.1 : Int(3:4, 4:1, requires_grad=0, device=cpu)):
//   ...
//   %7 : Tensor = prim::Uninitialized()
//   %16 : bool, %17 : Tensor, %y.14 : Tensor = prim::If(%15) #
//   test/onnx/test_pytorch_onnx_onnxruntime.py:614:20
//     block0():
//       %y.5 : Tensor = aten::add(%y.1, %3, %6) #
//       test/onnx/test_pytorch_onnx_onnxruntime.py:615:28
//       -> (%2, %7, %y.5)
//     block1():
//       -> (%1, %y.1, %7)
//   ...

void ONNXFixupUninitializedOutput(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::If) {
    return;
  }

  GRAPH_DUMP("Graph before fixing If shape type: ", node->owningGraph());
  auto* if_node = node;
  auto* graph = if_node->owningGraph();

  // Check if the input to the ONNX If node is not Bool, and insert a cast
  // to Bool if needed.
  if (!if_node->input()->type()->isSubtypeOf(*BoolType::get())) {
    Node* cast_node =
        InsertCastForCond(if_node->input(), graph, if_node, opset_version);
    cast_node->copyMetadata(if_node);
  }

  Block* then_block = if_node->blocks()[0];
  Block* else_block = if_node->blocks()[1];

  // Infer shape and type for subblock outputs.
  TORCH_INTERNAL_ASSERT(
      then_block->outputs().size() == else_block->outputs().size())
  for (const auto i : c10::irange(else_block->outputs().size())) {
    Value* then_block_output = then_block->outputs()[i];
    Value* else_block_output = else_block->outputs()[i];

    // If both subblocks have an uninitialized output, shape and type cannot
    // be inferred.
    TORCH_CHECK(
        !(IsUninitializedNode(then_block_output->node()) &&
          IsUninitializedNode(else_block_output->node())),
        "Cannot infer shape and type for ONNX If with uninitialized output in both subblocks. Please check the model graph.");

    if (IsUninitializedNode(then_block_output->node())) {
      InferShapeTypeForUninitializedOutput(
          graph,
          then_block,
          then_block_output,
          else_block_output,
          opset_version);
      if_node->outputs()[i]->setType(then_block->outputs()[i]->type());
    } else if (IsUninitializedNode(else_block_output->node())) {
      InferShapeTypeForUninitializedOutput(
          graph,
          else_block,
          else_block_output,
          then_block_output,
          opset_version);
      if_node->outputs()[i]->setType(else_block->outputs()[i]->type());
    }
  }
}

void ONNXMergeIfBlockOutputShapes(Node* node) {
  TORCH_INTERNAL_ASSERT(node->kind() == ::c10::onnx::If);
  Block* then_block = node->blocks().at(0);
  Block* else_block = node->blocks().at(1);

  TORCH_INTERNAL_ASSERT(
      then_block->outputs().size() == else_block->outputs().size())

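  // Merge two symbolic shapes: keep dimensions on which both shapes agree,
  // replace disagreeing dimensions with fresh symbols, and if the ranks
  // differ fall back to the first shape with a known positive rank (or an
  // unranked shape if neither has one).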
  auto findCommonShape =
      [](const ::c10::SymbolicShape& a,
         const ::c10::SymbolicShape& b) -> ::c10::SymbolicShape {
    std::vector<::c10::ShapeSymbol> dims;
    if (a.rank() && b.rank() && a.rank() == b.rank()) {
      for (const auto j : c10::irange(a.rank().value())) {
        if (a[j] == b[j]) {
          dims.emplace_back(a[j]);
        } else {
          dims.emplace_back(::c10::ShapeSymbol::newSymbol());
        }
      }
      return ::c10::SymbolicShape(dims);
    }
    if (a.rank() && a.rank().value() > 0) {
      return a;
    }
    if (b.rank() && b.rank().value() > 0) {
      return b;
    }

    return ::c10::SymbolicShape();
  };

  auto mergeTensorType =
      [&findCommonShape](TensorTypePtr a, TensorTypePtr b) -> TensorTypePtr {
    if (a && b) {
      const auto& a_shape = a->symbolic_sizes();
      const auto& b_shape = b->symbolic_sizes();
      auto commonShape = findCommonShape(a_shape, b_shape);
      return a->withSymbolicShapes(commonShape);
    } else if (a) {
      return a;
    } else if (b) {
      return b;
    }
    return nullptr;
  };

  auto mergeListType = [&mergeTensorType](
                           ListTypePtr a, ListTypePtr b) -> ListTypePtr {
    if (a && b) {
      auto a_tensor_type = a->getElementType()->cast<TensorType>();
      auto b_tensor_type = b->getElementType()->cast<TensorType>();
      auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type);
      if (tensor_type) {
        return a->withContained({tensor_type})->cast<ListType>();
      }
      // Both branches produce ListType without tensor shape.
      return a;
    } else if (a) {
      return a;
    } else if (b) {
      return b;
    }
    return nullptr;
  };

  auto mergeOptionalType = [&mergeTensorType, &mergeListType](
                               OptionalTypePtr a,
                               OptionalTypePtr b) -> OptionalTypePtr {
    if (a && b) {
      if (a->getElementType()->cast<TensorType>()) {
        auto a_tensor_type = a->getElementType()->cast<TensorType>();
        auto b_tensor_type = b->getElementType()->cast<TensorType>();
        auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type);
        if (tensor_type) {
          return a->withContained({tensor_type})->cast<OptionalType>();
        }
        // Both branches produce OptionalType without tensor shape.
        return a;
      } else if (a->getElementType()->cast<ListType>()) {
        auto a_list_type = a->getElementType()->cast<ListType>();
        auto b_list_type = b->getElementType()->cast<ListType>();
        auto list_type = mergeListType(a_list_type, b_list_type);
        if (list_type) {
          return a->withContained({list_type})->cast<OptionalType>();
        }
        // Both branches produce OptionalType without tensor shape.
        return a;
      }
    } else if (a) {
      return a;
    } else if (b) {
      return b;
    }
    return nullptr;
  };

  for (const auto i : c10::irange(else_block->outputs().size())) {
    Value* output_i = node->output(i);
    auto then_type = then_block->outputs().at(i)->type();
    auto else_type = else_block->outputs().at(i)->type();
    auto then_tensor_type = then_type->cast<TensorType>();
    auto else_tensor_type = else_type->cast<TensorType>();
    auto then_list_type = then_type->cast<ListType>();
    auto else_list_type = else_type->cast<ListType>();
    auto then_optional_type = then_type->cast<OptionalType>();
    auto else_optional_type = else_type->cast<OptionalType>();
    auto then_none_type = then_type->cast<NoneType>();
    auto else_none_type = else_type->cast<NoneType>();
    if (then_tensor_type || else_tensor_type) {
      if (TypePtr merged_type =
              mergeTensorType(then_tensor_type, else_tensor_type)) {
        if (else_optional_type || else_none_type || then_optional_type ||
            then_none_type) {
          merged_type = OptionalType::create(merged_type);
        }
        output_i->setType(merged_type);
      }
    } else if (then_list_type || else_list_type) {
      if (TypePtr merged_type = mergeListType(then_list_type, else_list_type)) {
        if (else_optional_type || else_none_type || then_optional_type ||
            then_none_type) {
          merged_type = OptionalType::create(merged_type);
        }
        output_i->setType(merged_type);
      }
    }

    if (then_optional_type || else_optional_type) {
      if (auto optional_type =
              mergeOptionalType(then_optional_type, else_optional_type)) {
        output_i->setType(optional_type);
        // Both branches' output types must match.
        if (!then_optional_type) {
          ReplaceBlockOutputWithOptional(optional_type, then_block, i);
        } else if (!else_optional_type) {
          ReplaceBlockOutputWithOptional(optional_type, else_block, i);
        }
      }
    }

    if (then_none_type && !else_optional_type) {
      ReplaceBlockOutputWithOptional(
          output_i->type()->cast<OptionalType>(), then_block, i);
    }

    if (else_none_type && !then_optional_type) {
      ReplaceBlockOutputWithOptional(
          output_i->type()->cast<OptionalType>(), else_block, i);
    }
  }
}

std::vector<Value*> FixupONNXIfNode(Node* node, int opset_version) {
  if (node->kind() != ::c10::onnx::If) {
    return node->outputs().vec();
  }
  GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph());
  FixupONNXSubblockOutputs(node);
  ONNXFixupUninitializedOutput(node, opset_version);
  ONNXMergeIfBlockOutputShapes(node);

  GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
  return node->outputs().vec();
}

std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version) {
  switch (n->kind()) {
    case ::c10::onnx::Loop: {
      return FixupONNXLoopNode(n, opset_version);
    }
    case ::c10::onnx::If: {
      return FixupONNXIfNode(n, opset_version);
    }
    default:
      return n->outputs().vec();
  }
}

void FixupONNXControlflowNodeOutputs(Node* n) {
  switch (n->kind()) {
    case ::c10::onnx::Loop: {
      Block* loop_block = n->blocks().at(0);
      // Block inputs (0, 1) are (i, cond); the remainder correspond to the
      // loop-carried outputs.
      size_t loop_carried_output_size = loop_block->inputs().size() - 2;

      for (auto i : c10::irange(n->outputs().size())) {
        if (i < loop_carried_output_size) {
          const TypePtr block_input_type =
              loop_block->inputs().at(i + 2)->type();
          const TypePtr block_output_type =
              loop_block->outputs().at(i + 1)->type();
          TypePtr type = block_output_type;
          // Handle the case where a block input is Optional but the
          // output is not (i.e. if the loop executes > 0 times, the
          // output will not be None).
          if (block_input_type->cast<OptionalType>() &&
              !block_output_type->cast<OptionalType>()) {
            type = OptionalType::create(block_output_type);
          }
          n->output(i)->setType(type);
        } else {
          // Scan output; should be a Tensor type.
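          // onnx::Loop concatenates each iteration's scan output along a
          // new leading axis, so prepend a fresh symbolic dimension to the
          // per-iteration shape.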
          TypePtr type = loop_block->outputs().at(i + 1)->type();
          if (auto t_type = type->cast<TensorType>()) {
            auto sizes = t_type->symbolic_sizes().sizes();
            if (sizes.has_value()) {
              sizes.value().emplace(
                  sizes.value().begin(), c10::ShapeSymbol::newSymbol());
              type = t_type->withSymbolicShapes(sizes.value());
            }
          }
          n->output(i)->setType(type);
        }
      }
      break;
    }
    case ::c10::onnx::If: {
      ONNXMergeIfBlockOutputShapes(n);
      break;
    }
    default:
      break;
  }
}

} // namespace torch::jit