#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/onnx.h>
#include <torch/csrc/utils/python_strings.h>

#include <torch/csrc/onnx/diagnostics/diagnostics.h>

#include <onnx/shape_inference/implementation.h>
#include <algorithm>
#include <cmath>
#include <iterator>
#include <limits>
#include <unordered_set>
#include <utility>

namespace torch::jit {

inline bool PyNone_Check(PyObject* o) {
  return o == Py_None;
}

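// Merge an existing type with a newly inferred type, preferring inferred
// information where it is more complete. For example, merging an existing
// Float(*, *, device=cpu) with an inferred Long(2, 3) keeps the existing
// device but adopts the inferred shape and scalar type, yielding
// Long(2, 3, device=cpu) with use_inferred_type = true.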
std::pair<TypePtr, bool> MergeInferredType(
    const TypePtr& existing_type,
    const TypePtr& inferred_type) {
  auto new_list_type = inferred_type->cast<ListType>();
  auto use_inferred_type = false;
  if (new_list_type) {
    return std::make_pair(inferred_type, true);
  }
  auto new_tensor_type = inferred_type->cast<TensorType>();
  auto old_tensor_type = existing_type->cast<TensorType>();

  if (new_tensor_type && old_tensor_type) {
    if (!old_tensor_type->device()) {
      // An unavailable device means this is an invalid tensor type (most
      // likely an empty one); return the inferred type directly.
      return std::make_pair(new_tensor_type, true);
    }
    auto type = old_tensor_type;
    if (new_tensor_type->dim()) {
      type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes());
      use_inferred_type = true;
    }
    if (new_tensor_type->scalarType().has_value()) {
      type = type->withScalarType(new_tensor_type->scalarType());
      use_inferred_type = true;
    }
    return std::make_pair(type, use_inferred_type);
  }

  if (old_tensor_type) {
    return std::make_pair(existing_type, false);
  }

  auto old_list_type = existing_type->cast<ListType>();
  if (new_tensor_type && old_list_type) {
    if (new_tensor_type->sizes().isComplete()) {
      return std::make_pair(inferred_type, true);
    }
    return std::make_pair(existing_type, false);
  }

  return std::make_pair(inferred_type, true);
}

void MergeInferredTypeAndSetMap(
    Value* dest_v,
    const TypePtr& existing_type,
    const TypePtr& inferred_type) {
  auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type);
  dest_v->setType(mergedType);
  ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
}

namespace {
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
namespace diagnostics = ::torch::onnx::diagnostics;

// SymbolDimMap is a Torch-to-ONNX shape look-up. This is built so it can be
// returned by the export function. During the export, however, when we come
// across new ONNX shapes, the reverse look-up is needed. To avoid incurring
// a linear-time look-up, we maintain DimSymbolMap in parallel.
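// For instance, if the ONNX graph uses the same dim_param "batch" on two
// different values, both dimensions map to one c10::ShapeSymbol, so
// downstream passes can tell the sizes are equal even though neither is
// statically known.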
c10::ShapeSymbol ONNXDimToShapeSymbol(
    const onnx::TensorShapeProto_Dimension& dim,
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map) {
  if (dim.has_dim_value()) {
    return c10::ShapeSymbol::fromStaticSize(dim.dim_value());
  }
  std::optional<c10::ShapeSymbol> sym = std::nullopt;
  if (dim.has_dim_param()) {
    // If this param is already known, assign the same Symbol.
    GRAPH_UPDATE("Got dim_param:", dim.dim_param());
    auto maybe_symbol = dim_symbol_map.find(dim.dim_param());
    if (maybe_symbol != dim_symbol_map.end()) {
      sym = maybe_symbol->second;
    }
  }
  if (!sym) {
    sym = c10::ShapeSymbol::newSymbol();
    // If dim.dim_param() is empty, no need to keep track
    // because there won't be duplicates.
    symbol_dim_map[sym.value()] = dim.dim_param();
    dim_symbol_map[dim.dim_param()] = sym.value();
  }
  return sym.value();
}

TensorTypePtr TorchTensorTypeFromONNX(
    const onnx::TypeProto_Tensor& onnx_tensor_type,
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map) {
  std::optional<at::ScalarType> scalar_type;
  if (onnx_tensor_type.has_elem_type()) {
    scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type());
  }

  auto v_type = TensorType::create(
      scalar_type,
      at::kCPU,
      c10::SymbolicShape(),
      c10::VaryingShape<c10::Stride>{},
      {});
  if (onnx_tensor_type.has_shape()) {
    std::vector<c10::ShapeSymbol> sizes;
    const auto& onnx_shape = onnx_tensor_type.shape();

    for (const auto i : c10::irange(onnx_shape.dim_size())) {
      sizes.emplace_back(ONNXDimToShapeSymbol(
          onnx_shape.dim(i), symbol_dim_map, dim_symbol_map));
    }
    v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
    v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));

    if (v_type->sizes().concrete_sizes().has_value()) {
      // Populate strides based on sizes info, if sizes are all static.
      // Creating strides ensures yielding True for isCompleteTensor.
      v_type = v_type->contiguous();
    }
  }

  return v_type;
}

ListTypePtr TorchListTypeFromONNX(
    const onnx::TypeProto_Sequence& onnx_sequence_type,
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map) {
  if (onnx_sequence_type.has_elem_type()) {
    const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type();
    if (onnx_seq_elem_type.has_tensor_type()) {
      const auto& onnx_tensor_type = onnx_seq_elem_type.tensor_type();
      const auto v_tensor_type = TorchTensorTypeFromONNX(
          onnx_tensor_type, symbol_dim_map, dim_symbol_map);
      auto v_type = ListType::create(v_tensor_type);
      return v_type;
    }
  }
  return nullptr;
}

void UpdateTorchValueByOnnxValueInfo(
    Value* v,
    const onnx::ValueInfoProto& p_info,
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map) {
  if (!p_info.has_type()) {
    return;
  }

  const auto& p_type = p_info.type();
  if (p_type.has_tensor_type()) {
    const auto torch_tensor_type = TorchTensorTypeFromONNX(
        p_type.tensor_type(), symbol_dim_map, dim_symbol_map);
    if (torch_tensor_type) {
      MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type);
    }
  } else if (p_type.has_sequence_type()) {
    const auto torch_list_type = TorchListTypeFromONNX(
        p_type.sequence_type(), symbol_dim_map, dim_symbol_map);
    if (torch_list_type) {
      MergeInferredTypeAndSetMap(v, v->type(), torch_list_type);
    }
  }
}

bool IsValidONNXControlflowNode(const Node* n) {
  // Skip when the block count is zero. This happens while the node is being
  // created and doesn't have subblocks attached yet. Run shape inference for
  // these nodes later, once the subgraph has completed shape inference.
  auto node_kind = n->kind();
  if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
    if (n->blocks().empty()) {
      return false;
    }
  }

  return true;
}

bool IsValidONNXNode(const Node* n) {
  auto node_kind = n->kind();

  if (!node_kind.is_onnx()) {
    // Node kind is not ONNX; skip it.
    return false;
  }

  if (!IsValidONNXControlflowNode(n)) {
    return false;
  }

  for (auto b : n->blocks()) {
    for (auto b_n : b->nodes()) {
      if (!IsValidONNXNode(b_n)) {
        return false;
      }
    }
  }

  return true;
}

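// For example, a node whose every output has at least one static dimension,
// e.g. Float(2, *, *), is treated as user-typed; a node with fully dynamic
// outputs, e.g. Float(*, *, *), is not.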
bool CustomSettype(Node* node) {
  // Helper to decide whether a non-ONNX node actually carries a custom
  // setType from the user.
  // Go through every output's symbolic_sizes: if any one of them is static,
  // we say the type was set by the user. On the other hand, if all of them
  // are * (dynamic), we assume the node was not given a type, since
  // unreliable nodes have * shapes anyway.
  auto all_output_has_type = [](Value* output) {
    if (auto output_type = output->type()->cast<TensorType>()) {
      if (auto sizes = output_type->symbolic_sizes().sizes()) {
        return std::any_of(std::begin(*sizes), std::end(*sizes), [](auto size) {
          return size.is_static();
        });
      }
    }
    return false;
  };

  return std::all_of(
      node->outputs().begin(), node->outputs().end(), all_output_has_type);
}

Value* CloneValueFromListConstruct(
    Value* v,
    const std::shared_ptr<Graph>& n_graph,
    int opset_version) {
  auto lc_node = v->node();
  TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct);
  // In jit/passes/onnx/peephole.cpp::eraseListConstruct,
  // prim::ListConstruct is converted to onnx::Concat. The conversion should
  // eventually be moved to symbolic. For now, treat this operator as a
  // special case, and change from list type to tensor type. The scalar type
  // is preserved. If the element type is Int, insert an onnx::Concat node
  // into the graph.
  TypePtr elem = v->type()->castRaw<ListType>()->getElementType();
  std::optional<at::ScalarType> scalar_type = std::nullopt;
  if (elem->cast<IntType>()) {
    scalar_type = at::kLong;
    if (isValidToTransformToONNXConcatNode(v->node())) {
      auto concat_node = transformToONNXConcatNode(
          n_graph.get(), v->node(), true, opset_version);
      return concat_node->output();
    }
  } else if (elem->cast<FloatType>()) {
    scalar_type = at::kFloat;
  } else if (elem->cast<BoolType>()) {
    scalar_type = at::kBool;
  } else if (auto t_type = elem->cast<TensorType>()) {
    scalar_type = t_type->scalarType();
  }

  auto input = n_graph->addInput();
  if (scalar_type) {
    auto v_type = TensorType::create(
        scalar_type.value(),
        at::kCPU,
        c10::SymbolicShape(),
        c10::VaryingShape<c10::Stride>{},
        {});
    input->setType(v_type);
  }
  return input;
}

// Clone the node n for the new graph.
Node* CloneNodeToGraph(
    Node* n,
    std::shared_ptr<Graph> n_graph,
    const ParamMap& params_dict,
    int opset_version) {
  auto clone_node = n_graph->createClone(
      n, [&n_graph, &params_dict, opset_version](Value* v) {
        auto v_n = v->node();
        switch (v_n->kind()) {
          case ::c10::prim::Constant:
          case ::c10::onnx::Constant: {
            // Clone the input if it is constant.
            auto constant_n = n_graph->insertNode(
                n_graph->createClone(v_n, [](Value* v) { return v; }));
            return constant_n->output();
          }
          case ::c10::prim::ListConstruct: {
            return CloneValueFromListConstruct(v, n_graph, opset_version);
          }
          case ::c10::prim::PackPadded: {
            auto input = n_graph->addInput();
            if (v == v_n->output(0)) {
              // Only the first output requires this workaround.
              // In the `peephole` pass, user nodes are modified to consume
              // the input instead.
              input->copyMetadata(v_n->input(0));
            } else {
              input->copyMetadata(v);
            }
            return input;
          }
          default: {
            // Try to look up the input value and insert it into the graph.
            // If the input value is unknown, set it as a graph input in the
            // new graph, and copy over metadata such as datatype and shape.
            ::std::optional<at::Tensor> val = ::std::nullopt;
            auto v0 = params_dict.find(v->debugName());
            if (v0 != params_dict.end()) {
              val = v0->second.toTensor();
            } else {
              val = ConstantValueMap::GetValue(v->debugName());
            }

            if (val.has_value()) {
              return n_graph
                  ->insertNode(n_graph->create(::c10::onnx::Constant)
                                   ->t_(attr::value, val.value()))
                  ->output();
            }
            auto input = n_graph->addInput();
            input->copyMetadata(v);
            return input;
          }
        }
      });
  return clone_node;
}

bool HasValidType(const TypePtr& type, const std::string& name) {
  if (auto t_type = type->cast<TensorType>()) {
    if (!t_type->scalarType().has_value()) {
      GRAPH_UPDATE("Input ", name, " is missing tensor datatype.");
      return false;
    }
  } else if (auto s_type = type->cast<ListType>()) {
    auto e_type = s_type->getElementType();
    return HasValidType(e_type, name);
  } else if (auto o_type = type->cast<OptionalType>()) {
    auto e_type = o_type->getElementType();
    return HasValidType(e_type, name);
  }
  return true;
}

bool IsGraphValidForInference(const std::shared_ptr<Graph>& graph) {
  // Verify that every input has a type (either Tensor, Sequence or Optional)
  // and a scalar type. This is a requirement for ONNX graph inputs.
  for (auto in : graph->inputs()) {
    if (!HasValidType(in->type(), in->debugName())) {
      return false;
    }
  }
  return true;
}

void ConvertGraphToONNXProto(
    const std::shared_ptr<Graph>& graph,
    std::shared_ptr<onnx::ModelProto>& model_proto,
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map,
    int opset_version) {
  auto
      [model_proto_tmp,
       export_map,
       new_symbol_dim_map,
       val_use_external_data_format,
       node_names] =
          export_onnx(
              graph,
              {},
              opset_version,
              {},
              false,
              onnx_torch::OperatorExportTypes::ONNX,
              true,
              true,
              {},
              true,
              false,
              std::string());
  model_proto = std::move(model_proto_tmp);
  symbol_dim_map.insert(new_symbol_dim_map.begin(), new_symbol_dim_map.end());
  for (const auto& pair : new_symbol_dim_map) {
    dim_symbol_map[pair.second] = pair.first;
  }
  for (int i = 0; i < model_proto->graph().output_size(); ++i) {
    model_proto->mutable_graph()->mutable_output(i)->clear_type();
  }
}

std::optional<at::Tensor> ComputeConstantFolding(Node* n, int opset_version) {
  if (n->inputs().empty()) {
    return std::nullopt;
  }
  std::vector<at::Tensor> inputTensorValues;
  for (auto i : c10::irange(n->inputs().size())) {
    if (TensorTypePtr input_type = n->input(i)->type()->cast<TensorType>()) {
      if (!ConstantValueMap::HasValue(n->input(i)->debugName())) {
        return std::nullopt;
      }
      auto tensor_value =
          ConstantValueMap::GetValue(n->input(i)->debugName()).value();
      inputTensorValues.emplace_back(tensor_value);
    }
  }
  if (inputTensorValues.size() < n->inputs().size()) {
    return std::nullopt;
  }
  try {
    return onnx_constant_fold::runTorchBackendForOnnx(
        n, inputTensorValues, opset_version);
  } catch (const std::exception& ex) {
    auto ex_str = std::string(ex.what());
    ex_str = ex_str.substr(0, ex_str.find('\n'));
    TORCH_WARN("Constant folding in symbolic shape inference fails: ", ex_str);
    return std::nullopt;
  }
}

// Similar to the function above, but for symbolic shapes.
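// Follows ONNX Reshape semantics: a 0 in the target shape copies the
// corresponding input dimension (unless allowzero=1), and a -1 absorbs the
// remaining elements. Example: input shape [2, 3, 4] reshaped with [0, -1]
// (allowzero=0) gives [2, 12].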
std::optional<::c10::SymbolicShape> ComputeShapeFromReshape(
    Node* n,
    const c10::SymbolicShape& input_shape,
    const c10::SymbolicShape& shape,
    int opset_version) {
  std::vector<c10::ShapeSymbol> input_shape_vector =
      input_shape.sizes().value();
  std::vector<c10::ShapeSymbol> shape_vector = shape.sizes().value();
  TORCH_INTERNAL_ASSERT(
      !input_shape_vector.empty() || !shape_vector.empty(),
      "Reshape node should have at least one input size > 0 when constant folding.");
  if (shape_vector.empty()) {
    return input_shape;
  }
  if (input_shape_vector.empty()) {
    return shape;
  }

  auto is_zero = [](c10::ShapeSymbol& ss) { return ss.value() == 0; };
  auto it_0 = std::find_if(shape_vector.begin(), shape_vector.end(), is_zero);
  bool shape_has_zero = it_0 != shape_vector.end();

  int minus_one_pos = -1;
  for (auto i : c10::irange(shape_vector.size())) {
    if (shape_vector[i].value() == -1) {
      minus_one_pos = i;
      break;
    }
  }

  int allowzero = 0;
  if (opset_version >= 14 && n->hasAttributeS("allowzero")) {
    allowzero = n->i(attr::allowzero);
  }

  TORCH_CHECK(
      !(shape_has_zero && allowzero == 1 && minus_one_pos != -1),
      "0 and -1 cannot both be present in `Shape` input of `Reshape` node, when `allowzero=1`.");

  if (minus_one_pos == -1 && (!shape_has_zero || allowzero)) {
    return shape;
  }
  std::vector<c10::ShapeSymbol> final_shape;
  uint64_t shape_ratio = 1;
  std::unordered_map<int64_t, int64_t> sym_map;
  for (const c10::ShapeSymbol& input_shape : input_shape_vector) {
    // input_shape.static_size() could be zero when torch.tensor([]) is used.
    if (input_shape.is_static() && input_shape.static_size() != 0) {
      if (shape_ratio >=
          std::numeric_limits<uint64_t>::max() / input_shape.static_size()) {
        TORCH_WARN(
            "ComputeShapeFromReshape(), shape_ratio overflows, skip shape inference.");
        return std::nullopt;
      } else {
        shape_ratio *= static_cast<uint64_t>(input_shape.static_size());
      }
    } else {
      auto value = input_shape.value();
      sym_map.emplace(value, 0).first->second += 1;
    }
  }
  int shape_size = static_cast<int>(shape_vector.size());
  for (const int i : c10::irange(shape_size)) {
    if (i == minus_one_pos) {
      continue;
    }
    c10::ShapeSymbol& target_shape = shape_vector[i];
    if (target_shape.value() == 0) {
      target_shape = input_shape_vector[i];
    }
    if (target_shape.is_static()) {
      shape_ratio /= static_cast<uint64_t>(target_shape.static_size());
    } else {
      auto value = target_shape.value();
      if (sym_map.find(value) == sym_map.end()) {
        return std::nullopt;
      }
      sym_map[value]--;
      if (sym_map[value] == 0) {
        sym_map.erase(value);
      }
    }
  }

  // sym_map is used to match shape symbols between the input and shape.
  // If there is a mismatch, the output shape cannot be estimated.
  if (!sym_map.empty()) {
    return std::nullopt;
  }

  TORCH_INTERNAL_ASSERT(
      minus_one_pos != -1,
      "There are no examples for shape_has_zero = true && minus_one_pos == -1.");

  for (const auto i : c10::irange(minus_one_pos)) {
    c10::ShapeSymbol cur_shape(
        shape_vector[i].value() == 0 ? input_shape_vector[i] : shape_vector[i]);
    final_shape.push_back(cur_shape);
  }
  if (minus_one_pos != -1) {
    final_shape.push_back(
        c10::ShapeSymbol::fromStaticSize(static_cast<int64_t>(shape_ratio)));
  }
  for (auto i = minus_one_pos + 1; i < shape_size; i++) {
    c10::ShapeSymbol cur_shape(
        shape_vector[i].value() == 0 ? input_shape_vector[i] : shape_vector[i]);
    final_shape.push_back(cur_shape);
  }
  c10::SymbolicShape final_shape_0(final_shape);
  return final_shape_0;
}

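// Example: expanding an input of shape [3, 1] with a target of [2, 1, 6]
// follows the usual right-aligned broadcasting rules: the trailing pair
// (1, 6) gives 6, the pair (3, 1) gives 3, and the leading 2 is kept,
// producing [2, 3, 6]. Any dynamic input dimension becomes a fresh symbol.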
std::optional<::c10::SymbolicShape> ComputeShapeFromExpand(
    const std::vector<::c10::ShapeSymbol>& input_shape,
    const std::vector<int64_t>& reshape) {
  for (const auto& it : reshape) {
    if (it < 0) {
      return std::nullopt;
    }
  }
  std::vector<::c10::ShapeSymbol> final_shape;
  if (input_shape.size() >= reshape.size()) {
    final_shape = input_shape;
  } else {
    for (auto v : reshape) {
      final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(v));
    }
  }
  auto min_size = std::min(input_shape.size(), reshape.size());
  for (const auto i : c10::irange(min_size)) {
    auto idx = final_shape.size() - i - 1;
    auto input_shape_idx = input_shape.size() - i - 1;
    auto reshape_idx = reshape.size() - i - 1;
    if (input_shape[input_shape_idx].is_static()) {
      auto input_shape_value = input_shape[input_shape_idx].static_size();
      auto reshape_value = reshape[reshape_idx];
      TORCH_INTERNAL_ASSERT(
          input_shape_value == reshape_value || input_shape_value == 1 ||
              reshape_value == 1,
          "ONNX Expand input shape constraint not satisfied.");
      final_shape[idx] = ::c10::ShapeSymbol::fromStaticSize(
          std::max(input_shape_value, reshape_value));

    } else {
      final_shape[idx] = ::c10::ShapeSymbol::newSymbol();
    }
  }
  ::c10::SymbolicShape shape(final_shape);
  return shape;
}

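// Example: tiling an input of shape [2, 3] with repeats [2, 2] multiplies
// each static dimension by its repeat count, giving [4, 6]; dynamic
// dimensions become fresh symbols.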
std::optional<::c10::SymbolicShape> ComputeShapeFromTile(
    const std::vector<::c10::ShapeSymbol>& input_shape,
    const std::vector<int64_t>& reshape) {
  TORCH_INTERNAL_ASSERT(
      input_shape.size() == reshape.size(),
      "ONNX Tile input shapes do not match.");
  for (const auto& it : reshape) {
    if (it < 0) {
      return std::nullopt;
    }
  }
  std::vector<::c10::ShapeSymbol> final_shape;
  final_shape.reserve(input_shape.size());
  for (const auto i : c10::irange(input_shape.size())) {
    if (input_shape[i].is_static()) {
      final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(
          input_shape[i].static_size() * reshape[i]));
    } else {
      final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
    }
  }
  ::c10::SymbolicShape shape(final_shape);
  return shape;
}

void UpdateRank(Value* value, size_t rank) {
  ConstantValueMap::SetRank(value->debugName(), rank);
  if (TensorTypePtr value_type = value->type()->cast<TensorType>()) {
    std::optional<size_t> rank_opt = rank;
    auto shape = ::c10::SymbolicShape(rank_opt);
    value->setType(value_type->withSymbolicShapes(shape));
  }
}

void UpdateShapeFromVector(
    Value* value,
    const std::vector<int64_t>& shape_size) {
  ::c10::SymbolicShape shape(shape_size);
  ConstantValueMap::SetShape(value->debugName(), shape);
  if (shape_size.empty()) {
    UpdateRank(value, 0);
    return;
  }
  ConstantValueMap::SetRank(value->debugName(), shape_size.size());
  if (TensorTypePtr value_type = value->type()->cast<TensorType>()) {
    value->setType(value_type->withSymbolicShapes(shape));
  }
}

void UpdateShape(Value* value, const ::c10::SymbolicShape& shape) {
  ConstantValueMap::SetShape(value->debugName(), shape);
  if (shape.rank().has_value()) {
    auto rank = shape.rank().value();
    if (rank == 0) {
      UpdateRank(value, 0);
      return;
    }
    ConstantValueMap::SetRank(value->debugName(), rank);
    if (TensorTypePtr value_type = value->type()->cast<TensorType>()) {
      value->setType(value_type->withSymbolicShapes(shape));
    }
  }
}

void UpdateShapeConstantValueMap(
    const Value* value,
    const ::c10::SymbolicShape& shape) {
  ConstantValueMap::SetShape(value->debugName(), shape);
  if (shape.rank().has_value()) {
    auto rank = shape.rank().value();
    ConstantValueMap::SetRank(value->debugName(), rank);
  }
}

std::optional<std::vector<int64_t>> GetValueFromListConstructNode(
    Node* lc_node) {
  std::vector<int64_t> shape_size;
  for (const auto& input : lc_node->inputs()) {
    if (input->type()->cast<TensorType>() &&
        ConstantValueMap::HasValue(input->debugName())) {
      auto lc_value = ConstantValueMap::GetValue(input->debugName()).value();
      if (lc_value.dim() == 0) {
        int64_t lc_value_0 = lc_value.item<int64_t>();
        shape_size.emplace_back(lc_value_0);
      }
    }
  }
  return lc_node->inputs().size() == shape_size.size()
      ? std::optional<std::vector<int64_t>>(shape_size)
      : std::nullopt;
}

void SetShapeValueFromListConstructNode(Node* lc_node) {
  std::vector<c10::ShapeSymbol> shape_size;
  for (const auto& input : lc_node->inputs()) {
    if (TensorTypePtr shape_type = input->type()->cast<TensorType>()) {
      if (ConstantValueMap::HasValue(input->debugName())) {
        auto lc_value = ConstantValueMap::GetValue(input->debugName()).value();
        if (lc_value.dim() == 0) {
          int64_t lc_value_0 = lc_value.item<int64_t>();
          shape_size.emplace_back(c10::ShapeSymbol::fromStaticSize(lc_value_0));
        }
      } else if (ConstantValueMap::HasShapeValue(input->debugName())) {
        auto lc_value =
            ConstantValueMap::GetShapeValue(input->debugName()).value();
        if (lc_value.rank() == 1U) {
          shape_size.emplace_back(lc_value.at(0));
        }
      }
    }
  }
  if (lc_node->inputs().size() == shape_size.size()) {
    c10::SymbolicShape final_shape(shape_size);
    ConstantValueMap::SetShapeValue(
        lc_node->output()->debugName(), final_shape);
  }
}

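// Standard right-aligned broadcasting over two symbolic shapes. Example:
// Broadcast([2, 1, 4], [3, 1]) aligns the trailing dimensions, takes the max
// of each static pair ((4, 1) -> 4, (1, 3) -> 3), and copies the leading 2,
// giving [2, 3, 4]. A 0-sized dimension wins over a 1-sized one; if one side
// of a pair is dynamic, or two dynamic symbols differ, the result stays a
// fresh symbol.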
std::vector<::c10::ShapeSymbol> Broadcast(
    const std::vector<::c10::ShapeSymbol>& input_shape_value_0,
    const std::vector<::c10::ShapeSymbol>& input_shape_value_1) {
  size_t rank_0 = input_shape_value_0.size();
  size_t rank_1 = input_shape_value_1.size();
  size_t rank_max = std::max(rank_0, rank_1);
  size_t rank_min = std::min(rank_0, rank_1);
  std::vector<::c10::ShapeSymbol> final_shape;
  final_shape.reserve(rank_max);
  std::generate_n(
      std::back_inserter(final_shape), rank_max, ::c10::ShapeSymbol::newSymbol);
  for (auto idx : c10::irange(rank_min)) {
    const c10::ShapeSymbol& ss_shape_0 = input_shape_value_0[rank_0 - 1 - idx];
    const c10::ShapeSymbol& ss_shape_1 = input_shape_value_1[rank_1 - 1 - idx];
    bool is_static_0 = ss_shape_0.is_static();
    bool is_static_1 = ss_shape_1.is_static();
    size_t shape_idx = rank_max - 1 - idx;
    if (is_static_0 && is_static_1) {
      int64_t static_0_sz = ss_shape_0.static_size();
      int64_t static_1_sz = ss_shape_1.static_size();
      // Corner case of a 0-sized dimension:
      // broadcasting a 0-sized dim with a 1-sized dim yields a 0-sized dim.
      if (std::min(static_0_sz, static_1_sz) == 0) {
        final_shape[shape_idx] = ::c10::ShapeSymbol::fromStaticSize(
            std::min(static_0_sz, static_1_sz));
      } else {
        final_shape[shape_idx] = ::c10::ShapeSymbol::fromStaticSize(
            std::max(static_0_sz, static_1_sz));
      }
    } else if (!is_static_0 && !is_static_1) {
      if (ss_shape_0.value() == ss_shape_1.value()) {
        final_shape[shape_idx] = ss_shape_0;
      }
    }
  }
  if (rank_0 < rank_1) {
    for (size_t idx = rank_min; idx < rank_max; idx++) {
      size_t shape_idx = rank_max - 1 - idx;
      final_shape[shape_idx] = input_shape_value_1[shape_idx];
    }
  } else {
    for (size_t idx = rank_min; idx < rank_max; idx++) {
      size_t shape_idx = rank_max - 1 - idx;
      final_shape[shape_idx] = input_shape_value_0[shape_idx];
    }
  }
  return final_shape;
}

void ProcessBroadcastNode(Node* n) {
  TORCH_INTERNAL_ASSERT(n->inputs().size() == 2);
  if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
      ConstantValueMap::HasShape(n->input(1)->debugName())) {
    auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
    auto input_shape_value_0 = input_shape_0.value().sizes().value();
    auto input_shape_1 = ConstantValueMap::GetShape(n->input(1)->debugName());
    auto input_shape_value_1 = input_shape_1.value().sizes().value();
    auto final_shape = Broadcast(input_shape_value_0, input_shape_value_1);
    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
  }
}

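// Infers the output shape of Concat. Along the concat axis, static sizes are
// summed; e.g. concatenating [2, 3] and [4, 3] on axis 0 yields [6, 3]. Off
// the concat axis, the first known static size is reused, and anything that
// cannot be resolved statically becomes a fresh symbol.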
void ProcessShapeForConcatNode(Node* n) {
  int axis = n->i(attr::axis);
  if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
    auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
    size_t axis_adjust = 0;
    if (axis >= 0) {
      axis_adjust = static_cast<size_t>(axis);
    } else {
      axis_adjust = static_cast<size_t>(axis + static_cast<int>(rank));
    }
    std::vector<::c10::ShapeSymbol> final_shape;
    final_shape.reserve(rank);
    for (auto idx : c10::irange(rank)) {
      if (idx == axis_adjust) {
        auto flag = true;
        int64_t size_total = 0;
        for (auto input_idx : c10::irange(n->inputs().size())) {
          if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
            auto input_shape =
                ConstantValueMap::GetShape(n->input(input_idx)->debugName());
            auto input_shape_value = input_shape.value().sizes();
            auto shape_symbol = input_shape_value.value()[idx];
            if (shape_symbol.is_static()) {
              size_total += shape_symbol.static_size();
            } else {
              flag = false;
              break;
            }
          }
        }
        if (flag) {
          final_shape.emplace_back(
              ::c10::ShapeSymbol::fromStaticSize(size_total));
        } else {
          final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
        }
      } else {
        auto flag = false;
        for (auto input_idx : c10::irange(n->inputs().size())) {
          if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
            auto input_shape =
                ConstantValueMap::GetShape(n->input(input_idx)->debugName());
            auto input_shape_value = input_shape.value().sizes();
            auto shape_symbol = input_shape_value.value()[idx];
            if (shape_symbol.is_static()) {
              final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(
                  shape_symbol.static_size()));
              flag = true;
              break;
            }
          }
        }
        if (!flag) {
          final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
        }
      }
    }
    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
  }
}

void ProcessShapeValueForConcatNode(Node* n) {
  auto rank = n->inputs().size();
  std::vector<c10::ShapeSymbol> shape_size;
  for (const auto& input : n->inputs()) {
    if (ConstantValueMap::HasValue(input->debugName())) {
      auto concat_value =
          ConstantValueMap::GetValue(input->debugName()).value();
      if (concat_value.dim() == 1 && concat_value.size(0) == 1) {
        auto concat_value_0 = concat_value[0].item<int64_t>();
        shape_size.emplace_back(
            c10::ShapeSymbol::fromStaticSize(concat_value_0));
      }
    } else if (ConstantValueMap::HasShapeValue(input->debugName())) {
      auto concat_value =
          ConstantValueMap::GetShapeValue(input->debugName()).value();
      if (concat_value.rank() == 1U) {
        shape_size.emplace_back(concat_value.at(0));
      }
    }
  }
  if (rank == shape_size.size()) {
    c10::SymbolicShape final_shape(shape_size);
    ConstantValueMap::SetShapeValue(n->output(0)->debugName(), final_shape);
  }
}

void ProcessConcatNode(Node* n) {
  ProcessShapeForConcatNode(n);
  ProcessShapeValueForConcatNode(n);
}

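// Example: MatMul of shapes [2, 1, 3, 4] and [2, 4, 5] broadcasts the batch
// dimensions ([2, 1] with [2] -> [2, 2]) and then applies
// [n, k] x [k, m] = [n, m] to the trailing matrix dimensions, giving
// [2, 2, 3, 5]. Rank-1 inputs are promoted to matrices first and the
// temporary dimension is dropped again, matching numpy.matmul.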
void ProcessMatMulNode(Node* n) {
  if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
      ConstantValueMap::HasShape(n->input(1)->debugName())) {
    auto input_shape_0 =
        ConstantValueMap::GetShape(n->input(0)->debugName()).value();
    auto input_shape_value_0 = input_shape_0.sizes().value();
    auto input_shape_1 =
        ConstantValueMap::GetShape(n->input(1)->debugName()).value();
    auto input_shape_value_1 = input_shape_1.sizes().value();
    size_t rank_0 = input_shape_value_0.size();
    size_t rank_1 = input_shape_value_1.size();
    // Handle inputs of rank 1 just like numpy.matmul:
    // https://numpy.org/doc/stable/reference/generated/numpy.matmul.html
    auto is_rank_0_1 = false;
    if (rank_0 == 1) {
      input_shape_value_0.insert(
          input_shape_value_0.begin(), ::c10::ShapeSymbol::fromStaticSize(1));
      rank_0 = 2;
      is_rank_0_1 = true;
    }
    auto is_rank_1_1 = false;
    if (rank_1 == 1) {
      input_shape_value_1.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
      rank_1 = 2;
      is_rank_1_1 = true;
    }
    // Per https://pytorch.org/docs/stable/generated/torch.matmul.html
    // the broadcasting logic only applies to the batch dimensions, not the
    // matrix dimensions, so we remove the matrix dimensions (the last 2
    // dimensions) before broadcasting.
    auto final_shape = Broadcast(
        std::vector<::c10::ShapeSymbol>(
            input_shape_value_0.begin(), input_shape_value_0.end() - 2),
        std::vector<::c10::ShapeSymbol>(
            input_shape_value_1.begin(), input_shape_value_1.end() - 2));
    // Add the last 2 dimensions back, unless they did not exist in the first
    // place and were inserted by this function. Then apply
    // [n, k] x [k, m] = [n, m], where n = input_shape_value_0[rank_0 - 2] and
    // m = input_shape_value_1[rank_1 - 1].
    if (!is_rank_0_1) {
      final_shape.emplace_back(input_shape_value_0[rank_0 - 2]);
    }
    if (!is_rank_1_1) {
      final_shape.emplace_back(input_shape_value_1[rank_1 - 1]);
    }
    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
  }
}

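// Example: a Reduce op over axes = [1] on an input of shape [2, 3, 4] yields
// [2, 1, 4] with keepdims = 1 (the ONNX default) and [2, 4] with
// keepdims = 0. When no axes are given, all dimensions are reduced.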
void ProcessReduceNode(Node* n) {
  if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
    auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
    auto input_shape_value_0 = input_shape_0.value().sizes();
    size_t rank_0 = input_shape_value_0.value().size();
    std::vector<::c10::ShapeSymbol> final_shape;
    std::vector<int64_t> axes_vector(rank_0);
    if (n->hasAttributeS("axes")) {
      axes_vector = n->is(attr::axes);
    } else if (n->inputs().size() > 1) {
      axes_vector =
          ConstantValueMap::GetValueInto1DInt64Vector(n->input(1)->debugName());
    } else {
      std::iota(axes_vector.begin(), axes_vector.end(), 0);
    }

    for (auto idx : c10::irange(axes_vector.size())) {
      if (axes_vector[idx] < 0) {
        axes_vector[idx] += rank_0;
      }
    }
    final_shape.reserve(rank_0);
    // ONNX keepdims defaults to 1 when not set.
    int64_t keepdims = 1;
    if (n->hasAttributeS("keepdims")) {
      keepdims = n->i(attr::keepdims);
    }
    for (auto idx : c10::irange(rank_0)) {
      auto it = std::find(axes_vector.begin(), axes_vector.end(), idx);
      if (it != axes_vector.end()) {
        if (keepdims != 0) {
          final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
        }
      } else {
        final_shape.emplace_back(input_shape_value_0.value()[idx]);
      }
    }
    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
  }
}

void ProcessReshapeNode(Node* n, int opset_version) {
  const auto& input_name = n->input(0)->debugName();
  const auto& shape_name = n->input(1)->debugName();

  // When the `shape` input value is statically known, compute output shape.
  if (ConstantValueMap::HasValue(shape_name)) {
    auto static_shape_value =
        ConstantValueMap::GetValueInto1DInt64Vector(shape_name);
    auto symbolic_input_shape = ConstantValueMap::GetShape(input_name);
    if (symbolic_input_shape && !static_shape_value.empty()) {
      auto final_shape = ComputeShapeFromReshape(
          n,
          symbolic_input_shape.value(),
          c10::SymbolicShape(static_shape_value),
          opset_version);
      if (final_shape) {
        UpdateShape(n->output(), final_shape.value());
        return;
      }
    }
  }

  // When the `shape` input value is symbolically known, compute output shape.
  if (ConstantValueMap::HasShapeValue(shape_name) &&
      ConstantValueMap::HasShape(input_name)) {
    auto symbolic_input_shape = ConstantValueMap::GetShape(input_name).value();
    auto symbolic_shape_value =
        ConstantValueMap::GetShapeValue(shape_name).value();
    auto final_shape = ComputeShapeFromReshape(
        n, symbolic_input_shape, symbolic_shape_value, opset_version);
    if (final_shape.has_value()) {
      UpdateShape(n->output(), final_shape.value());
      return;
    }
  }

  // When only the shape of the new shape is known, assign the output rank.
  if (ConstantValueMap::HasShape(shape_name)) {
    auto output_rank = ConstantValueMap::GetShapeInto1DInt64Vector(shape_name);
    if (output_rank.has_value()) {
      TORCH_INTERNAL_ASSERT(output_rank.value().size() == 1);
      UpdateRank(n->output(), output_rank.value()[0]);
      return;
    }
  }

  // ListConstruct is handled at the beginning of ProcessConstantValueMap; no
  // further processing here.
  if (TensorTypePtr shape_type = n->input(1)->type()->cast<TensorType>()) {
    // Set rank to Reshape output if possible.
    // From shape inference, we have:
    // %4236 : Float(*, device=cpu) = onnx::Transpose[perm=[0]](%4235)
    // %4237 : Long(2, strides=[1], device=cpu) = onnx::Concat[axis=0](%4232)
    // %4238 : FloatTensor(device=cpu) = onnx::Reshape(%4236, %4237)
    // We can have it as SymbolicShape with known rank:
    // %4238 : Float(*, *, strides=[2480, 1], requires_grad=0, device=cpu) =
    // onnx::Reshape(%4236, %4237)
    auto shape_type_dim = shape_type->dim();
    if (shape_type_dim.has_value()) {
      auto shape_type_size = shape_type->sizes()[0];
      if (shape_type_size.has_value()) {
        size_t rank = shape_type_size.value();
        UpdateRank(n->output(), rank);
      }
    }
  }
}

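// Computes the output shape of Slice, given fully normalized (non-negative)
// axes. Example: a static dimension of size 6 sliced with start = 1, end = 5,
// step = 2 evaluates (end - start - 1) / step + 1, selecting indices 1 and 3,
// so the output dimension is 2. Dynamic dimensions become fresh symbols.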
c10::SymbolicShape ComputeShapeForSlice(
    const std::vector<c10::ShapeSymbol>& input_shape,
    const std::vector<int64_t>& start_vector,
    const std::vector<int64_t>& end_vector,
    const std::vector<int64_t>& axes_vector,
    const std::vector<int64_t>& step_vector) {
  TORCH_INTERNAL_ASSERT(axes_vector.size() <= input_shape.size());
  TORCH_INTERNAL_ASSERT(axes_vector.size() == start_vector.size());
  TORCH_INTERNAL_ASSERT(axes_vector.size() == end_vector.size());
  TORCH_INTERNAL_ASSERT(axes_vector.size() == step_vector.size());
  std::vector<c10::ShapeSymbol> final_shape;
  final_shape = input_shape;
  for (const auto idx : c10::irange(axes_vector.size())) {
    auto axis = axes_vector[idx];
    TORCH_INTERNAL_ASSERT(axis >= 0);
    if (!input_shape[axis].is_static()) {
      final_shape[axis] = c10::ShapeSymbol::newSymbol();
      continue;
    }
    auto input_shape_axis_value = input_shape[axis].static_size();
    auto cur_start = start_vector[idx];
    auto cur_end = end_vector[idx];
    auto cur_step = step_vector[idx];
    if (cur_start < -input_shape_axis_value) {
      cur_start = 0;
    } else if (cur_start < 0) {
      cur_start = input_shape_axis_value + cur_start;
    } else if (cur_start > input_shape_axis_value - 1) {
      cur_start = input_shape_axis_value;
    }
    if (cur_end < -input_shape_axis_value) {
      cur_end = -1;
    } else if (cur_end < 0) {
      cur_end = input_shape_axis_value + cur_end;
    } else if (cur_end > input_shape_axis_value - 1) {
      cur_end = input_shape_axis_value;
    }
    TORCH_INTERNAL_ASSERT(cur_step != 0);
    if (cur_step > 0) {
      final_shape[axis] = c10::ShapeSymbol::fromStaticSize(
          (cur_end - cur_start - 1) / cur_step + 1);
    } else {
      final_shape[axis] = c10::ShapeSymbol::fromStaticSize(
          (cur_start - cur_end - 1) / (-cur_step) + 1);
    }
  }
  return c10::SymbolicShape(final_shape);
}

void ProcessSliceNode(Node* n, int opset_version) {
  bool valid = ConstantValueMap::HasShape(n->input(0)->debugName());

  // For opset version <= 9, starts, ends, and axes are attributes
  // (steps does not exist yet), so their values are always valid.
  if (opset_version >= 10) {
    // We can only infer shapes if 'axes' is known.
    if (n->inputs().size() > 3) {
      valid = valid && ConstantValueMap::HasValue(n->input(3)->debugName());
    }
  }

  if (!valid) {
    if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
      auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
      UpdateRank(n->output(), rank);
    }
    return;
  } else {
    auto shape_size_0 =
        ConstantValueMap::GetShape(n->input(0)->debugName()).value();
    if (shape_size_0.rank().has_value()) {
      auto input0_shape_value = shape_size_0.sizes().value();

      std::vector<int64_t> start_vector;
      std::vector<int64_t> end_vector;
      std::vector<int64_t> step_vector;

      std::vector<int64_t> axes_vector(input0_shape_value.size(), 0);
      for (const auto i : c10::irange(input0_shape_value.size())) {
        axes_vector[i] = i;
      }
      if (opset_version >= 10 && n->inputs().size() > 3) {
        axes_vector = ConstantValueMap::GetValueInto1DInt64Vector(
            n->input(3)->debugName());
      } else if (opset_version < 10 && n->hasAttributeS("axes")) {
        axes_vector = n->is(attr::axes);
      }
      for (auto& axis : axes_vector) {
        if (axis < 0) {
          axis += input0_shape_value.size();
        }
      }

      if (opset_version < 10) {
        start_vector = n->is(attr::starts);
        end_vector = n->is(attr::ends);
      } else {
        // If starts, ends, or steps are unknown,
        // then mark all dimensions in 'axes' as unknown.
        std::vector<uint64_t> indices = {1U, 2U, 4U};
        bool start_end_step_known =
            std::all_of(indices.begin(), indices.end(), [&n](auto i) {
              return (i >= n->inputs().size()) ||
                  ConstantValueMap::HasValue(n->input(i)->debugName());
            });
        if (!start_end_step_known) {
          auto final_shape = input0_shape_value;
          for (const auto axis : axes_vector) {
            final_shape[axis] = c10::ShapeSymbol::newSymbol();
          }
          UpdateShape(n->output(), final_shape);
          return;
        }

        start_vector = ConstantValueMap::GetValueInto1DInt64Vector(
            n->input(1)->debugName());
        end_vector = ConstantValueMap::GetValueInto1DInt64Vector(
            n->input(2)->debugName());
        if (n->inputs().size() > 4) {
          step_vector = ConstantValueMap::GetValueInto1DInt64Vector(
              n->input(4)->debugName());
        }
      }

      if (step_vector.empty()) {
        step_vector = std::vector<int64_t>(axes_vector.size(), 1);
      }

      auto final_shape = ComputeShapeForSlice(
          input0_shape_value,
          start_vector,
          end_vector,
          axes_vector,
          step_vector);
      UpdateShape(n->output(), final_shape);
    }
  }
}

void ProcessUnchangeNode(Node* n) {
  if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
    auto shape_size_0 =
        ConstantValueMap::GetShape(n->input(0)->debugName()).value();
    UpdateShape(n->output(), shape_size_0);
  }
}

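// Shape inference for RNN/LSTM/GRU. Per the ONNX operator spec, output 0 is
// [seq_length, num_directions, batch_size, hidden_size], and the hidden/cell
// state outputs are [num_directions, batch_size, hidden_size]. hidden_size is
// recovered from the weight tensor's second dimension, which stacks the
// gates: 4 * hidden_size for LSTM and 3 * hidden_size for GRU.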
void ProcessTimeSeriesNode(Node* n) {
  auto input0_shape = ConstantValueMap::GetShape(n->input(0)->debugName());
  auto input1_shape = ConstantValueMap::GetShape(n->input(1)->debugName());
  if (!(input0_shape.has_value() && input1_shape.has_value())) {
    return;
  }
  auto input0_shape_value = input0_shape.value().sizes();
  auto input1_shape_value = input1_shape.value().sizes();
  c10::ShapeSymbol seq_length;
  c10::ShapeSymbol num_directions;
  c10::ShapeSymbol batch_size;
  c10::ShapeSymbol hidden_size;
  if (input0_shape_value.has_value()) {
    seq_length = input0_shape_value.value()[0];
    batch_size = input0_shape_value.value()[1];
  }

  if (input1_shape_value.has_value()) {
    num_directions = input1_shape_value.value()[0];
    if (input1_shape_value.value()[1].is_static()) {
      auto input1_value = input1_shape_value.value()[1].static_size();
      switch (n->kind()) {
        case ::c10::onnx::RNN:
          hidden_size = c10::ShapeSymbol::fromStaticSize(input1_value);
          break;
        case ::c10::onnx::LSTM:
          hidden_size = c10::ShapeSymbol::fromStaticSize(input1_value / 4);
          break;
        case ::c10::onnx::GRU:
          hidden_size = c10::ShapeSymbol::fromStaticSize(input1_value / 3);
          break;
        default:
          throw std::runtime_error(
              std::string() + "This is not a valid TimeSeries Node with type " +
              n->kind().toDisplayString());
      }
    } else {
      hidden_size = c10::ShapeSymbol::newSymbol();
    }
  }

  if (n->outputs().size() > 1) {
    std::vector<c10::ShapeSymbol> final_shape = {
        seq_length, num_directions, batch_size, hidden_size};
    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
  }
  for (const auto idx : c10::irange(2U, 4U)) {
    if (n->outputs().size() > idx) {
      std::vector<c10::ShapeSymbol> final_shape = {
          num_directions, batch_size, hidden_size};
      UpdateShape(n->output(idx - 1), c10::SymbolicShape(final_shape));
    }
  }
}

void ProcessUnsqueezeNode(Node* n) {
  TensorTypePtr output_type = n->output(0)->type()->cast<TensorType>();
  if (output_type == nullptr) {
    return;
  }
  if (output_type->dim().has_value() && output_type->dim().value() == 1 &&
      ConstantValueMap::HasShapeValue(n->input(0)->debugName())) {
    auto shape_value =
        ConstantValueMap::GetShapeValue(n->input(0)->debugName()).value();
    // When the scalar represents a shape, it is the same as the shape value
    // when it gets unsqueezed.
    ConstantValueMap::SetShapeValue(n->output()->debugName(), shape_value);
  }
}

// As an addition to ONNX shape inference, this function leverages constant
// folding and a per-op process to update the rank/shape of values in the
// graph; it also updates ConstantValueMap accordingly.
void ComputeConstant(Node* n, int opset_version) {
1248   if (n->kind() == ::c10::onnx::Constant) {
1249     if (n->kindOf(attr::value) == AttributeKind::t) {
1250       at::Tensor const_val = n->t(attr::value);
1251       at::Tensor const_val_copy =
1252           at::empty(const_val.sizes(), const_val.options());
1253       const_val_copy.copy_(const_val);
1254       ConstantValueMap::SetValue(n->output()->debugName(), const_val_copy);
1255     }
1256     return;
1257   }
1258   auto only_rank_available = false;
1259   size_t rank = 0;
1260 
1261   // Constant folding.
1262   auto const_fold_val = ComputeConstantFolding(n, opset_version);
1263   if (const_fold_val.has_value()) {
1264     at::Tensor const_fold_val_copy = at::empty(
1265         const_fold_val.value().sizes(), const_fold_val.value().options());
1266     const_fold_val_copy.copy_(const_fold_val.value());
1267     ConstantValueMap::SetValue(n->output()->debugName(), const_fold_val_copy);
1268     UpdateShapeFromVector(n->output(), const_fold_val_copy.sizes().vec());
1269     return;
1270   }
1271 
1272   switch (n->kind()) {
1273     case ::c10::onnx::Add:
1274     case ::c10::onnx::Div:
1275     case ::c10::onnx::Equal:
1276     case ::c10::onnx::Greater:
1277     case ::c10::onnx::GreaterOrEqual:
1278     case ::c10::onnx::Less:
1279     case ::c10::onnx::LessOrEqual:
1280     case ::c10::onnx::Mod:
1281     case ::c10::onnx::Mul:
1282     case ::c10::onnx::Pow:
1283     case ::c10::onnx::Sub: {
1284       ProcessBroadcastNode(n);
1285       break;
1286     }
1287     case ::c10::onnx::Shape: {
1288       auto input_shape =
1289           ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName());
1290       if (input_shape.has_value()) {
1291         auto shape_value = input_shape.value();
1292         // TODO: getDevice() ?
1293         auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
1294         auto shape_value_size = static_cast<int64_t>(shape_value.size());
1295         auto f =
1296             at::from_blob(shape_value.data(), {shape_value_size}, at::kLong)
1297                 .to(at::kCPU);
1298         // Need copy here
1299         at::Tensor f_copy = at::empty({shape_value_size}, options);
1300         f_copy.copy_(f);
1301         ConstantValueMap::SetValue(n->output()->debugName(), f_copy);
1302         std::vector<::c10::ShapeSymbol> final_shape_vector(
1303             1, c10::ShapeSymbol::fromStaticSize(shape_value_size));
1304         ::c10::SymbolicShape final_shape(final_shape_vector);
1305         UpdateShape(n->output(), final_shape);
1306       }
1307       break;
1308     }
1309     case ::c10::onnx::Reshape: {
1310       ProcessReshapeNode(n, opset_version);
1311       break;
1312     }
1313     case ::c10::onnx::Transpose: {
1314       if (n->hasAttributeS("perm")) {
1315         auto perm_v = n->is(attr::perm);
1316         rank = perm_v.size();
1317         auto is_default_perm = false;
1318         if (rank == 2 && perm_v[0] == 1 && perm_v[1] == 0) {
1319           is_default_perm = true;
1320         }
1321         auto shape_updated = false;
1322         if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1323           auto shape_size_0 =
1324               ConstantValueMap::GetShape(n->input(0)->debugName())
1325                   .value()
1326                   .sizes();
1327           if (shape_size_0.has_value()) {
1328             auto shape_vector_0 = shape_size_0.value();
1329             std::vector<::c10::ShapeSymbol> final_shape_vector(
1330                 shape_vector_0.size(), ::c10::ShapeSymbol());
1331             if (is_default_perm) {
1332               std::reverse_copy(
1333                   std::begin(shape_vector_0),
1334                   std::end(shape_vector_0),
1335                   std::begin(final_shape_vector));
1336             } else {
1337               for (const auto i : c10::irange(shape_vector_0.size())) {
1338                 final_shape_vector[i] = shape_vector_0[perm_v[i]];
1339               }
1340             }
1341             ::c10::SymbolicShape final_shape(final_shape_vector);
1342             UpdateShape(n->output(), final_shape);
1343             shape_updated = true;
1344           }
1345         }
1346         if (!shape_updated) {
1347           if (!is_default_perm) {
1348             only_rank_available = true;
1349           } else if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
1350             rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
1351             only_rank_available = true;
1352           }
1353         }
1354       }
1355       break;
1356     }
1357     case ::c10::onnx::Concat: {
1358       ProcessConcatNode(n);
1359       break;
1360     }
1361     case ::c10::onnx::ConstantOfShape: {
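      // Illustrative example: if input `shape` folds to the constant {2, 3},
      // the output is a 2x3 tensor filled with the `value` attribute
      // (0.0f when the attribute is absent), and both the shape and the
      // constant value are recorded in ConstantValueMap.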
1362       if (ConstantValueMap::HasValue(n->input()->debugName())) {
1363         auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
1364             n->input()->debugName());
1365         UpdateShapeFromVector(n->output(), shape_temp);
1366         if (!shape_temp.empty()) {
1367           if (n->hasAttributeS("value")) {
1368             auto value = n->t(attr::value).repeat(shape_temp);
1369             ConstantValueMap::SetValue(n->output()->debugName(), value);
1370           } else {
1371             auto options =
1372                 c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
1373             auto value = at::full({1}, 0.0, options).repeat(shape_temp);
1374             ConstantValueMap::SetValue(n->output()->debugName(), value);
1375           }
1376         }
1377       }
1378       break;
1379     }
1380     case ::c10::onnx::Expand: {
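      // Illustrative example: expanding an input of shape [3, 1] with a
      // constant `shape` of {2, 1, 4} gives output shape [2, 3, 4]; shapes
      // are right-aligned and size-1 dims are broadcast.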
1381       if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1382         auto input0_shape_size =
1383             ConstantValueMap::GetShape(n->input(0)->debugName())
1384                 .value()
1385                 .sizes();
1386         if (input0_shape_size.has_value()) {
1387           auto input0_shape_value = input0_shape_size.value();
1388           if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
1389             // When the value of `shape` is statically known, the
1390             // output shape can be computed.
1391             auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
1392                 n->input(1)->debugName());
1393             auto final_shape =
1394                 ComputeShapeFromExpand(input0_shape_value, shape_temp);
1395             if (final_shape.has_value()) {
1396               UpdateShape(n->output(), final_shape.value());
1397             }
1398           } else if (
1399               auto expand_shape =
1400                   ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
1401                       n->input(1)->debugName())) {
1402             // When the shape of `shape` is statically known, the
1403             // output rank can be computed.
1404             TORCH_INTERNAL_ASSERT(
1405                 expand_shape.value().size() == 1,
1406                 "`Shape` input to `Expand` should be a 1-D tensor. Instead got rank ",
1407                 expand_shape.value().size());
1408             if (expand_shape.value()[0] > 0) {
1409               std::vector<c10::ShapeSymbol> final_shape;
1410               std::generate_n(
1411                   std::back_inserter(final_shape),
1412                   expand_shape.value()[0],
1413                   ::c10::ShapeSymbol::newSymbol);
1414               UpdateShape(n->output(), c10::SymbolicShape(final_shape));
1415             }
1416           }
1417         }
1418       }
1419       break;
1420     }
1421     case ::c10::onnx::NonZero: {
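      // onnx::NonZero produces a [input_rank, num_nonzero] tensor. The first
      // dim is fixed by the input rank; num_nonzero stays symbolic unless the
      // input is a ConstantOfShape filled with a nonzero value, in which case
      // every element is nonzero and num_nonzero equals the element count.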
1422       if (ConstantValueMap::HasRank(n->input()->debugName())) {
1423         auto rank = ConstantValueMap::GetRank(n->input()->debugName()).value();
1424         std::vector<c10::ShapeSymbol> dims;
1425         dims.emplace_back(
1426             c10::ShapeSymbol::fromStaticSize(static_cast<int64_t>(rank)));
1427         auto input_node = n->input()->node();
1428         if (input_node->kind() == ::c10::onnx::ConstantOfShape) {
1429           if (input_node->hasAttributeS("value")) {
1430             auto value =
1431                 input_node->t(attr::value).toType(at::ScalarType::Float);
1432             auto value_a = value.accessor<float, 1>();
1433             if (value_a.size(0) == 1 && std::abs(value_a[0]) > 1e-6) {
1434               if (ConstantValueMap::HasShape(n->input()->debugName())) {
1435                 auto shape_size_0 =
1436                     ConstantValueMap::GetShape(n->input()->debugName()).value();
1437                 if (shape_size_0.isComplete()) {
1438                   auto shape_vector_0 = shape_size_0.sizes().value();
1439                   int64_t num_elements = 1;
1440                   for (auto cur_dim : shape_vector_0) {
1441                     num_elements *= cur_dim.static_size();
1442                   }
1443                   dims.emplace_back(c10::ShapeSymbol::fromStaticSize(
1444                       static_cast<int64_t>(num_elements)));
1445                 }
1446               }
1447             }
1448           }
1449         }
1450         if (dims.size() == 1) {
1451           dims.emplace_back(c10::ShapeSymbol::newSymbol());
1452         }
1453         c10::SymbolicShape shape_v(dims);
1454         UpdateShape(n->output(), shape_v);
1455       }
1456       break;
1457     }
1458     case ::c10::onnx::MatMul: {
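      // Illustrative example: MatMul of shapes [2, 3, 4] and [4, 5] gives
      // output shape [2, 3, 5], following NumPy matmul broadcasting rules.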
1459       ProcessMatMulNode(n);
1460       break;
1461     }
1462     case ::c10::onnx::ReduceMean:
1463     case ::c10::onnx::ReduceProd: {
1464       ProcessReduceNode(n);
1465       break;
1466     }
1467     case ::c10::onnx::RNN:
1468     case ::c10::onnx::LSTM:
1469     case ::c10::onnx::GRU: {
1470       ProcessTimeSeriesNode(n);
1471       break;
1472     }
1473     case ::c10::onnx::Size: {
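      // onnx::Size folds to a constant scalar when all input dims are
      // static. Illustrative example: input shape [2, 3, 4] folds to 24.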
1474       if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1475         auto input0_shape_size =
1476             ConstantValueMap::GetShape(n->input(0)->debugName())
1477                 .value()
1478                 .sizes();
1479         if (input0_shape_size.has_value()) {
1480           auto input0_shape_value = input0_shape_size.value();
1481           int64_t total_size = 1;
1482           auto is_full_static = true;
1483           for (const auto i : c10::irange(input0_shape_value.size())) {
1484             if (input0_shape_value[i].is_static()) {
1485               total_size *= input0_shape_value[i].static_size();
1486             } else {
1487               is_full_static = false;
1488               break;
1489             }
1490           }
1491           if (is_full_static) {
1492             auto f_final = onnx_constant_fold::IntToTensor(total_size);
1493             ConstantValueMap::SetValue(n->output(0)->debugName(), f_final);
1494           }
1495         }
1496       }
1497       break;
1498     }
1499     case ::c10::onnx::Slice: {
1500       ProcessSliceNode(n, opset_version);
1501       break;
1502     }
1503     case ::c10::onnx::Cast:
1504     case ::c10::onnx::Relu:
1505     case ::c10::onnx::Softmax: {
1506       ProcessUnchangeNode(n);
1507       break;
1508     }
1509     case ::c10::onnx::Tile: {
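      // Illustrative example: input shape [2, 3] with constant repeats
      // {2, 1} gives output shape [4, 3]; each dim is multiplied by its
      // repeat count.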
1510       if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1511         auto input0_shape_size =
1512             ConstantValueMap::GetShape(n->input(0)->debugName())
1513                 .value()
1514                 .sizes();
1515         if (input0_shape_size.has_value()) {
1516           auto input0_shape_value = input0_shape_size.value();
1517           if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
1518             auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
1519                 n->input(1)->debugName());
1520             auto final_shape =
1521                 ComputeShapeFromTile(input0_shape_value, shape_temp);
1522             if (final_shape.has_value()) {
1523               UpdateShape(n->output(), final_shape.value());
1524             }
1525           }
1526         }
1527       }
1528       break;
1529     }
1530     case ::c10::onnx::Unsqueeze: {
1531       ProcessUnsqueezeNode(n);
1532       break;
1533     }
1534     default: {
1535       break;
1536     }
1537   }
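  // If no case above produced a full output shape, fall back to recording
  // only the rank (e.g., Transpose with a non-default perm whose input
  // shape is unknown).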
1538   if (n->outputs().size() > 1 ||
1539       ConstantValueMap::HasShape(n->output(0)->debugName())) {
1540     return;
1541   }
1542   if (only_rank_available) {
1543     UpdateRank(n->output(), rank);
1544   }
1545 }
1546 
1547 bool IsListConstructIntType(const Value* v) {
1548   if (v->node()->kind() == prim::ListConstruct) {
1549     auto listType = v->node()->output()->type();
1550     auto containedType = listType->containedTypes().at(0);
1551     if (containedType == IntType::get()) {
1552       return true;
1553     }
1554   }
1555   return false;
1556 }
1557 
1558 // Check if all graph inputs are static, returning a cached result when
1559 // available. Traversing all graph inputs (including weights) can be costly
1560 // for large graphs; since this is called for each node during export and
1561 // the inputs remain unchanged, caching cuts down export time.
1562 bool AllGraphInputsStaticWithCaching(const Graph* g) {
1563   auto maybe_is_static = ConstantValueMap::GetAllGraphInputsStatic();
1564   if (maybe_is_static.has_value()) {
1565     return maybe_is_static.value();
1566   } else {
1567     bool ret = AllGraphInputsStatic(g);
1568     ConstantValueMap::SetAllGraphInputsStatic(ret);
1569     return ret;
1570   }
1571 }
1572 
1573 void ProcessConstantValueMap(Node* n, int opset_version) {
1574   // Update ConstantValueMap on node outputs from ONNX shape inference.
1575   // For outputs, only static shapes are updated; for inputs, symbolic
1576   // shapes are updated as well. ONNX If nodes can have different types on
1577   // different branches, so they are skipped here.
1578 
1579   // Update the shape reliability for each node before processing
1580   // ConstantValueMap to prevent unreliable nodes from producing static
1581   // shapes
1582   UpdateReliable(n);
1583 
1584   auto static_input_shape = AllGraphInputsStaticWithCaching(n->owningGraph());
1585   for (auto i : c10::irange(n->outputs().size())) {
1586     if (TensorTypePtr output_type = n->output(i)->type()->cast<TensorType>()) {
1587       if (output_type->dim().has_value()) {
1588         size_t rank = static_cast<size_t>(output_type->dim().value());
1589         ConstantValueMap::SetRank(n->output(i)->debugName(), rank);
1590         auto shape = output_type->symbolic_sizes();
1591         if (shape.isComplete()) {
1592           UpdateShape(n->output(i), shape);
1593         }
1594       }
1595     }
1596   }
1597   // Update ConstantValueMap on node inputs from ONNX shape inference.
1598   // ListConstruct is handled here (only IntType is considered, not
1599   // TensorType), so there is no need for a per-op process.
1600   for (auto i : c10::irange(n->inputs().size())) {
1601     if (TensorTypePtr input_type = n->input(i)->type()->cast<TensorType>()) {
1602       if (input_type->dim().has_value()) {
1603         size_t rank = static_cast<size_t>(input_type->dim().value());
1604         ConstantValueMap::SetRank(n->input(i)->debugName(), rank);
1605         // Only update the shape if the input comes from an ONNX node.
1606         // For aten operators, for example,
1607         //   Float(20, 20, strides=[1, 0], requires_grad=0, device=cpu),
1608         //     %399 : Float(20, 20, strides=[0, 1], requires_grad=0, device=cpu)
1609         //     = prim::ListUnpack(%397)
1610         // the tracer shape may not be correct when dynamic_axes is enabled.
1611         if (n->input(i)->node()->kind().is_onnx() || static_input_shape) {
1612           auto shape = input_type->symbolic_sizes();
1613           if (!ConstantValueMap::HasShape(n->input(i)->debugName())) {
1614             UpdateShape(n->input(i), shape);
1615           }
1616         }
1617       }
1618     } else if (IsListConstructIntType(n->input(i))) {
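      // Illustrative example: %sizes : int[] = prim::ListConstruct(%a, %b)
      // with known %a = 2, %b = 3 folds to the constant tensor {2, 3} of
      // shape [2]; if the element values are unknown, only the list length
      // is recorded as the shape.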
1619       auto lc_node = n->input(i)->node();
1620       auto rank = lc_node->inputs().size();
1621       auto lc_vector_optional = GetValueFromListConstructNode(lc_node);
1622       if (lc_vector_optional.has_value()) {
1623         auto lc_vector = lc_vector_optional.value();
1624         auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
1625         auto lc_vector_size = static_cast<int64_t>(lc_vector.size());
1626         auto f = at::from_blob(lc_vector.data(), {lc_vector_size}, at::kLong)
1627                      .to(at::kCPU);
1628         // Need a copy here: from_blob does not own the underlying data.
1629         at::Tensor f_copy = at::empty({lc_vector_size}, options);
1630         f_copy.copy_(f);
1631         ConstantValueMap::SetValue(n->input(i)->debugName(), f_copy);
1632         UpdateShapeFromVector(n->input(i), {lc_vector_size});
1633       } else {
1634         UpdateShapeFromVector(n->input(i), {static_cast<int64_t>(rank)});
1635       }
1636       SetShapeValueFromListConstructNode(lc_node);
1637     }
1638   }
1639   // Additional logic to update the graph and ConstantValueMap
1640   ComputeConstant(n, opset_version);
1641 }
1642 
1643 // Any additional post-processing that is specific to an individual node kind.
1644 void SpecialPostProcess(Node* n) {
1645   switch (n->kind()) {
1646     case ::c10::onnx::SequenceInsert: {
1647       // Special case when the input sequence to SequenceInsert is empty.
1648       // The ONNX Sequence type requires its element type to be set.
1649       // If the sequence to insert into is empty, we set the element type
1650       // by looking at the tensor being inserted.
1651       auto seq_node = n->input(0)->node();
1652       auto t_type = n->input(1)->type()->cast<TensorType>();
1653 
1654       auto update_sequence_empty_dtype = [](Node* n,
1655                                             const TensorTypePtr& t_type) {
1656         TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty);
1657         TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value());
1658         auto scalar_type = t_type->scalarType().value();
1659         auto onnx_type = ATenTypeToOnnxType(scalar_type);
1660         n->i_(attr::dtype, onnx_type);
1661         n->output()->setType(ListType::create(t_type));
1662       };
1663 
1664       auto find_sequence_empty = [](Value* input,
1665                                     TensorTypePtr t_type) -> Node* {
1666         auto find_sequence_empty_impl =
1667             [](Value* input,
1668                TensorTypePtr t_type,
1669                auto& find_sequence_empty_ref) -> Node* {
1670           auto input_node = input->node();
1671           TORCH_INTERNAL_ASSERT(input_node);
1672 
1673           // 1. Input is from SequenceEmpty.
1674           if (input_node->kind() == ::c10::onnx::SequenceEmpty) {
1675             return input_node;
1676           }
1677 
1678           // 2. Input is a subblock input of a Loop node, which takes the
1679           // outer block's SequenceEmpty as input.
1680           if (input_node->kind() == prim::Param) {
1681             auto loop_n = input_node->owningBlock()->owningNode();
1682             if (nullptr == loop_n || loop_n->kind() != ::c10::onnx::Loop) {
1683               return nullptr;
1684             }
1685 
1686             auto it = std::find(
1687                 input_node->outputs().begin(),
1688                 input_node->outputs().end(),
1689                 input);
1690             auto idx = std::distance(input_node->outputs().begin(), it);
1691 
1692             auto outer_block_node = loop_n->input(idx)->node();
1693             if (outer_block_node &&
1694                 outer_block_node->kind() == ::c10::onnx::SequenceEmpty) {
1695               // Found SequenceEmpty
1696               input->setType(ListType::create(t_type));
1697               return outer_block_node;
1698             } else {
1699               // The outer block node is still not SequenceEmpty; recurse to
1700               // handle nested loops.
1701               auto found_n = find_sequence_empty_ref(
1702                   loop_n->input(idx), t_type, find_sequence_empty_ref);
1703               if (found_n) {
1704                 input->setType(ListType::create(t_type));
1705               }
1706               return found_n;
1707             }
1708           }
1709 
1710           // Could not find source SequenceEmpty node.
1711           return nullptr;
1712         };
1713         return find_sequence_empty_impl(
1714             input, std::move(t_type), find_sequence_empty_impl);
1715       };
1716 
1717       if (seq_node && t_type && t_type->scalarType()) {
1718         if (seq_node->kind() == ::c10::onnx::SequenceEmpty) {
1719           update_sequence_empty_dtype(seq_node, t_type);
1720         } else if (seq_node->kind() == prim::Param) {
1721           // Try to find original onnx::SequenceEmpty node in outer block.
1722           auto seq_empty_n = find_sequence_empty(n->input(0), t_type);
1723           if (seq_empty_n) {
1724             update_sequence_empty_dtype(seq_empty_n, t_type);
1725           }
1726         }
1727         n->output()->setType(ListType::create(t_type));
1728       }
1729 
1730       break;
1731     }
1732     case ::c10::onnx::Cast: {
1733       // ONNX shape inference is not able to assign the output tensor
1734       // shape when the input to onnx::Cast has an incomplete tensor type,
1735       // e.g. missing shape, rank, or dtype. This post-process sets the
1736       // correct dtype for the output tensor, since the dtype info is
1737       // stored in the Cast attribute.
1738       TensorTypePtr t_type = n->output()->type()->cast<TensorType>();
1739       if (nullptr != t_type && !t_type->scalarType().has_value()) {
1740         auto onnx_dtype = n->i(attr::to);
1741         auto aten_dtype = ONNXTypeToATenType(onnx_dtype);
1742         n->output()->setType(t_type->withScalarType(aten_dtype));
1743       }
1744       break;
1745     }
1746     case ::c10::onnx::ConstantOfShape: {
1747       // ONNX shape inference is not able to propagate the output tensor
1748       // shape for onnx::ConstantOfShape if input `shape` is not constant.
1749       // This is a temporary solution for when some partial information is
1750       // available, e.g. the rank of the output tensor or its symbolic
1751       // shape. This solution won't be needed once we have proper symbolic
1752       // propagation.
1753       auto shape_node = n->input(0)->node();
1754       if (shape_node->kind() == ::c10::onnx::Shape) {
1755         // Shape -> ConstantOfShape
1756         auto orig_type = shape_node->input()->type()->cast<TensorType>();
1757         auto v_type = n->output()->type()->cast<TensorType>();
1758         if (v_type && !v_type->sizes().concrete_sizes()) {
1759           if (orig_type && orig_type->dim()) {
1760             // Assign symbolic shape of original input of onnx::Shape.
1761             v_type = v_type->withSymbolicShapes(orig_type->symbolic_sizes());
1762             n->output()->setType(v_type);
1763           } else if (
1764               shape_node->input()->node()->kind() ==
1765               ::c10::prim::ListConstruct) {
1766             // Assign rank of original input of onnx::Shape.
1767             v_type = v_type->withSizes({static_cast<int64_t>(
1768                 shape_node->input()->node()->inputs().size())});
1769             n->output()->setType(v_type);
1770           }
1771         }
1772       } else if (shape_node->kind() == ::c10::prim::ListConstruct) {
1773         // ListConstruct -> ConstantOfShape
1774         auto v_type = n->output()->type()->cast<TensorType>();
1775         if (v_type && !v_type->sizes().concrete_sizes()) {
1776           auto value = n->t(attr::value);
1777           v_type = v_type->withScalarType(value.scalar_type());
1778           std::vector<c10::ShapeSymbol> sizes(
1779               shape_node->inputs().size(), c10::ShapeSymbol::newSymbol());
1780           v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));
1781           n->output()->setType(v_type);
1782         }
1783       }
1784       break;
1785     }
1786     case ::c10::onnx::If: {
1787       if (!IsValidONNXControlflowNode(n)) {
1788         break;
1789       }
1790       FixupONNXControlflowNodeOutputs(n);
1791       break;
1792     }
1793     case ::c10::onnx::Loop: {
1794       if (!IsValidONNXControlflowNode(n)) {
1795         break;
1796       }
1797       FixupONNXControlflowNodeOutputs(n);
1798       break;
1799     }
1800   }
1801 }
1802 
1803 void UpdateOutputTypeByONNXProto(
1804     Node* n,
1805     Node* clone_node,
1806     const onnx::ModelProto& model_proto,
1807     SymbolDimMap& symbol_dim_map,
1808     DimSymbolMap& dim_symbol_map) {
1809   const auto& graph_proto = model_proto.graph();
1810 
1811   // Get data from value_info and update the original graph.
1812   const auto updateNodeOutputsByONNXValueInfo =
1813       [&](const onnx::ValueInfoProto& v_info) {
1814         for (size_t i = 0; i < n->outputs().size(); ++i) {
1815           if (clone_node->output(i)->debugName() == v_info.name()) {
1816             UpdateTorchValueByOnnxValueInfo(
1817                 n->output(i), v_info, symbol_dim_map, dim_symbol_map);
1818           }
1819         }
1820       };
1821 
1822   // Check graph outputs for inferred shapes.
1823   for (const auto i : c10::irange(graph_proto.output_size())) {
1824     updateNodeOutputsByONNXValueInfo(graph_proto.output(i));
1825   }
1826 
1827   // Check value_infos for inferred shapes.
1828   for (const auto i : c10::irange(graph_proto.value_info_size())) {
1829     updateNodeOutputsByONNXValueInfo(graph_proto.value_info(i));
1830   }
1831 }
1832 
1833 void FetchBlockInputMetadataFromParent(Block* b) {
1834   auto n = b->owningNode();
1835   if (nullptr != n && n->kind() == ::c10::onnx::Loop) {
1836     // Copy node input metadata to subgraph input.
1837     for (size_t i = 0; i < n->inputs().size(); ++i) {
1838       b->inputs().at(i)->setType(n->inputs().at(i)->type());
1839     }
1840   }
1841 }
1842 
1843 void RemoveProcessedInputs(const Node* n) {
1844   // After processing a node for shape inference, remove intermediate tensors
1845   // that are stored in ConstantValueMap to reduce memory usage.
1846   // This will only remove tensors that are no longer needed by any other node.
1847 
1848   // Returns whether a node was already processed for shape inference.
1849   const auto isNodeProcessed = [](const Node* node) {
1850     const auto& outputs = node->outputs();
1851     return std::any_of(outputs.begin(), outputs.end(), [](const Value* output) {
1852       // Assumes shape inference can at least determine the rank of the outputs.
1853       // If this assumption is wrong, some intermediate tensors will only be
1854       // deleted once shape inference is completed for the entire graph.
1855       return ConstantValueMap::HasRank(output->debugName());
1856     });
1857   };
1858 
1859   // An input value is no longer needed if all of its consumer nodes
1860   // have already been processed.
1861   const auto isValueNoLongerNeeded = [isNodeProcessed](const Value* input) {
1862     const auto& uses = input->uses();
1863     return std::all_of(
1864         uses.begin(), uses.end(), [isNodeProcessed](const Use& use) {
1865           return isNodeProcessed(use.user);
1866         });
1867   };
1868 
1869   for (const auto* input : n->inputs()) {
1870     if (ConstantValueMap::HasValue(input->debugName()) &&
1871         isValueNoLongerNeeded(input)) {
1872       ConstantValueMap::EraseValue(input->debugName());
1873     }
1874   }
1875 }
1876 
1877 void ONNXShapeTypeInference(
1878     Block* b,
1879     const ParamMap& params_dict,
1880     int opset_version) {
1881   FetchBlockInputMetadataFromParent(b);
1882   auto valsToParamsMap = buildValueToParamsMap(b, params_dict);
1883   for (auto const& it : valsToParamsMap) {
1884     auto key = it.first;
1885     auto value = it.second;
1886     if (key->node()->kind() == prim::Param) {
1887       if (value.second.isTensor()) {
1888         ConstantValueMap::SetValue(value.first, value.second.toTensor());
1889       }
1890     } else if (key->node()->kind() == ::c10::onnx::Constant) {
1891       at::Tensor const_val = key->node()->t(attr::value);
1892       at::Tensor const_val_copy =
1893           at::empty(const_val.sizes(), const_val.options());
1894       const_val_copy.copy_(const_val);
1895       ConstantValueMap::SetValue(value.first, const_val_copy);
1896     } else {
1897       throw std::runtime_error(
1898           "ONNXShapeTypeInference - Unsupported kind of constant node found.");
1899     }
1900   }
1901   for (auto n : b->nodes()) {
1902     for (auto subblock : n->blocks()) {
1903       ONNXShapeTypeInference(subblock, params_dict, opset_version);
1904     }
1905     ONNXShapeTypeInference(n, params_dict, opset_version);
1906     RemoveProcessedInputs(n);
1907   }
1908 }
1909 
1910 } // namespace
1911 
1912 // For some operators, certain inputs are not related to shape inference.
1913 // For example, LSTM input 4 (sequence_lens) is optional, and shape
1914 // inference can be done through the other required inputs. When computing
1915 // reliability, such inputs do not need to be reliable.
1916 static std::unordered_map<std::string, std::unordered_set<int64_t>>
1917     non_required_shape_inference_idx_map = {{"onnx::LSTM", {4}}};
1918 
1919 bool AllGraphInputsStatic(const Graph* g) {
1920   for (auto n : g->inputs()) {
1921     if (TensorTypePtr input_type = n->type()->cast<TensorType>()) {
1922       if (input_type->dim()) {
1923         auto shape = input_type->symbolic_sizes();
1924         if (!ConstantValueMap::HasShape(n->debugName())) {
1925           UpdateShapeConstantValueMap(n, shape);
1926         }
1927       }
1928     }
1929   }
1930   for (auto n : g->inputs()) {
1931     // Some inputs can be non-Tensor type, e.g.,
1932     // __torch__.torch.classes.quantized.LinearPackedParamsBase
1933     // so we only need check Tensor type here.
1934     if (n->type()->cast<TensorType>() && !n->isCompleteTensor()) {
1935       return false;
1936     }
1937   }
1938   return true;
1939 }
1940 
1941 std::pair<bool, bool> AreInputsReliableOrStatic(Node* n) {
1942   auto reliable = true;
1943   auto complete = true;
1944   auto input_size = n->inputs().size();
1945   std::unordered_set<int64_t> non_required_idx = {};
1946   if (non_required_shape_inference_idx_map.find(n->kind().toDisplayString()) !=
1947       non_required_shape_inference_idx_map.end()) {
1948     non_required_idx =
1949         non_required_shape_inference_idx_map[n->kind().toDisplayString()];
1950   }
1951   for (auto idx : c10::irange(input_size)) {
1952     if (!non_required_idx.empty() &&
1953         non_required_idx.find(idx) != non_required_idx.end()) {
1954       continue;
1955     }
1956     auto input = n->inputs()[idx];
1957     // Always consider None reliable and complete, because it represents
1958     // unspecified optional inputs in ONNX.
1959     if (input->node()->mustBeNone()) {
1960       continue;
1961     }
1962     reliable &=
1963         ConstantValueMap::GetTypeReliable(input->debugName()).value_or(false);
1964     if (auto pt = input->type()->cast<TensorType>()) {
1965       if (!pt->sizes().isComplete()) {
1966         complete = false;
1967       }
1968     }
1969   }
1970   return std::make_pair(reliable, complete);
1971 }
1972 
1973 // There is no need to put ONNX node types here, but we need them
1974 // for some legacy tests that run with onnx_shape_inference=False.
1975 static std::unordered_set<std::string> nodeTypeReliableForTracer = {
1976     "prim::ListConstruct",
1977     "onnx::Cast",
1978     "onnx::Constant",
1979     "onnx::Relu",
1980     "com.microsoft::Gelu",
1981     "aten::ATen"};
1982 
1983 void UpdateReliable(
1984     torch::jit::Value* output,
1985     const std::pair<bool, bool>& inferred_type_reliable,
1986     bool no_type_warning) {
1987   auto inferred =
1988       ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false);
1989   auto isTypeReliableForTracer =
1990       nodeTypeReliableForTracer.find(
1991           output->node()->kind().toDisplayString()) !=
1992       nodeTypeReliableForTracer.end();
1993   if (!inferred && !isTypeReliableForTracer &&
1994       !output->node()->kind().is_onnx() && no_type_warning) {
1995     TORCH_WARN(
1996         "The shape inference of ",
1997         output->node()->kind().toDisplayString(),
1998         " type is missing, so it may result in wrong shape inference for the exported graph. ",
1999         "Please consider adding it in symbolic function.");
2000     // Experimental, nothing sent to stdout nor stderr.
2001     diagnostics::Diagnose(
2002         diagnostics::Rule::kNodeMissingOnnxShapeInference,
2003         diagnostics::Level::kWarning,
2004         {{"op_name", output->node()->kind().toDisplayString()}});
2005   }
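  // A type is considered reliable if it came from ONNX shape inference over
  // reliable inputs or, for tracer-trusted node kinds, if all input shapes
  // are statically complete.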
2006   auto reliable = false;
2007   if (inferred) {
2008     reliable = inferred_type_reliable.first;
2009   } else {
2010     if (inferred_type_reliable.second && isTypeReliableForTracer) {
2011       reliable = true;
2012     }
2013   }
2014   // Assuming the tracer estimates ranks correctly, the output tensor
2015   // of Shape should always be reliable.
2016   if (output->node()->kind() == ::c10::onnx::Shape) {
2017     reliable = true;
2018   }
2019   ConstantValueMap::SetTypeReliable(output->debugName(), reliable);
2020   if (!reliable) {
2021     if (auto output_tensor_type = output->type()->cast<TensorType>()) {
2022       output->setType(output_tensor_type->withSymbolicShapes(
2023           ::c10::SymbolicShape(output_tensor_type->dim())));
2024     }
2025   }
2026 }
2027 
2028 void UpdateReliable(Node* n) {
2029   auto input_reliable = AreInputsReliableOrStatic(n);
2030   for (auto output : n->outputs()) {
2031     UpdateReliable(output, input_reliable);
2032   }
2033 }
2034 
2035 // Traverse the graph inputs and compute reliability (e.g., whether shapes
2036 // are static). Since the inputs do not change during export, we save
2037 // computation time by marking the result as computed and skipping it later.
2038 void SetGraphInputTypeReliable(const Graph* g) {
2039   if (!ConstantValueMap::GetAllGraphInputsReliableComputed()) {
2040     for (auto graph_input : g->inputs()) {
2041       if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) {
2042         ConstantValueMap::SetTypeReliable(graph_input->debugName(), true);
2043       }
2044     }
2045     ConstantValueMap::SetAllGraphInputsReliableComputed(true);
2046   }
2047 }
2048 
2049 void ONNXShapeTypeInference(
2050     Node* n,
2051     const ParamMap& params_dict,
2052     int opset_version) {
2053   std::unordered_map<std::string, std::string> torch_to_onnx_input;
2054   std::unordered_map<std::string, std::string> torch_to_onnx_output;
2055   auto& original_shape_data = ConstantValueMap::GetInferredShapeData();
2056   ShapeDataMap inferred_shape_data;
2057   auto& symbol_dim_map = ConstantValueMap::GetSymbolDimMap();
2058   auto& dim_symbol_map = ConstantValueMap::GetDimSymbolMap();
2059 
2060   SetGraphInputTypeReliable(n->owningGraph());
2061   GRAPH_UPDATE(
2062       "Running ONNX shape inference for node: ", n->kind().toDisplayString());
2063 
2064   if (IsValidONNXNode(n)) {
2065     // Create a Graph containing only the single node n.
2066     // This graph is later converted to ONNX to run shape inference.
2067     auto n_graph = std::make_shared<Graph>();
2068     auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
2069     n_graph->insertNode(clone_node);
2070 
2071     // Register all node outputs as graph outputs.
2072     for (auto output : clone_node->outputs()) {
2073       n_graph->registerOutput(output);
2074     }
2075 
2076     // Map the original PyTorch graph's i/o names to the temporary
2077     // ONNX graph's i/o names for shape inference.
2078     for (size_t i = 0; i < clone_node->inputs().size(); ++i) {
2079       torch_to_onnx_input[n->input(i)->debugName()] =
2080           clone_node->input(i)->debugName();
2081     }
2082 
2083     for (size_t i = 0; i < clone_node->outputs().size(); ++i) {
2084       torch_to_onnx_output[n->output(i)->debugName()] =
2085           clone_node->output(i)->debugName();
2086     }
2087     // Make inferred_shape_data use names from the temporary ONNX graph
2088     // instead of the original PyTorch graph. Only copy what we need,
2089     // which are the inputs of n.
2090     for (auto input : n->inputs()) {
2091       const auto maybe_shape = original_shape_data.find(input->debugName());
2092       if (maybe_shape != original_shape_data.end()) {
2093         const auto onnx_input_name =
2094             torch_to_onnx_input.find(input->debugName());
2095         if (onnx_input_name != torch_to_onnx_input.end()) {
2096           inferred_shape_data[onnx_input_name->second] = maybe_shape->second;
2097         }
2098       }
2099     }
2100     // Use scalar_type_analysis without low precision cast
2101     ScalarTypeAnalysisForONNX(n_graph, false, opset_version);
2102 
2103     GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString());
2104     GRAPH_DEBUG(
2105         "Cloned torch graph to run shape inference: ", n_graph->toString());
2106 
2107     if (IsGraphValidForInference(n_graph)) {
2108       // TODO: Some ops have conversion happen at Peephole pass.
2109       //       The conversion here is incomplete for these ops.
2110       //       e.g: ListConstruct, ListUnpack, etc.
2111       std::shared_ptr<onnx::ModelProto> model_proto;
2112       ConvertGraphToONNXProto(
2113           n_graph, model_proto, symbol_dim_map, dim_symbol_map, opset_version);
2114       GRAPH_DEBUG(
2115           "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
2116 
2117       // infer shape
2118       try {
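        // For Shape and Gather, ONNX partial data propagation is enabled so
        // that values (not just shapes) can be inferred, e.g. through
        // Shape -> Gather chains that select a single dimension.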
2119         // TODO(#79208): Enable more operators to support data propagation
2120         switch (n->kind()) {
2121           case ::c10::onnx::Shape:
2122           case ::c10::onnx::Gather: {
2123             auto* schema_registry = onnx::OpSchemaRegistry::Instance();
2124             onnx::ShapeInferenceOptions options{
2125                 /*check_type_val=*/false,
2126                 /*strict_mode_val=*/0,
2127                 /*data_prop_val=*/true};
2128             onnx::shape_inference::InferShapes(
2129                 *model_proto, schema_registry, options, &inferred_shape_data);
2130             break;
2131           }
2132           default: {
2133             onnx::shape_inference::InferShapes(*model_proto);
2134             break;
2135           }
2136         }
2137         UpdateOutputTypeByONNXProto(
2138             n, clone_node, *model_proto, symbol_dim_map, dim_symbol_map);
2139       } catch (std::runtime_error& ex) {
2140         // TODO: include this as warning once we have a more consolidated
2141         // warning system.
2142         GRAPH_DEBUG(
2143             "ONNX shape inference fails with: ",
2144             ex.what(),
2145             " on graph: ",
2146             n_graph->toString());
2147         // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2148         const char shape_err[] = "ShapeInferenceError";
2149         // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2150         const char type_err[] = "TypeInferenceError";
2151         if ((strstr(ex.what(), shape_err) == nullptr) &&
2152             (strstr(ex.what(), type_err) == nullptr)) {
2153           throw;
2154         }
2155       }
2156       GRAPH_DEBUG(
2157           "ONNX graph after shape inference: ", prettyPrint(*model_proto));
2158     }
2159   } else if (CustomSettype(n)) {
2160     // If the node is not standard ONNX, go through every output to check
2161     // whether they all have shapes. If they all do, the types are treated
2162     // as reliable even though the op is not from ONNX.
2163     for (auto node_output : n->outputs()) {
2164       // Outputs with a custom setType should end up here if set correctly.
2165       // They are marked as inferred for the later UpdateReliable call.
2166       ConstantValueMap::SetUseInferredType(node_output->debugName(), true);
2167     }
2168   }
2169 
2170   SpecialPostProcess(n);
2171   // Get data propagation result from ONNX shape inference
2172   for (const auto& output : n->outputs()) {
2173     const auto inferred_shape_pair =
2174         inferred_shape_data.find(torch_to_onnx_output[output->debugName()]);
2175     if (inferred_shape_pair != inferred_shape_data.end()) {
2176       const auto& inferred_shape = inferred_shape_pair->second;
2177       int rank = inferred_shape.dim_size();
2178       std::vector<::c10::ShapeSymbol> final_shape(rank);
2179       for (int i = 0; i < rank; ++i) {
2180         final_shape[i] = ONNXDimToShapeSymbol(
2181             inferred_shape.dim(i), symbol_dim_map, dim_symbol_map);
2182       }
2183       c10::SymbolicShape shape_value(final_shape);
2184       // Store data propagation result into shapeValueMap
2185       ConstantValueMap::SetShapeValue(output->debugName(), shape_value);
2186       // Use original name in PyTorch graph instead of
2187       // temporary name in intermediate ONNX graph
2188       // Add this back to original_shape_data
2189       original_shape_data[output->debugName()] = inferred_shape;
2190     }
2191   }
2192 
2193   if (IsValidONNXNode(n)) {
2194     ProcessConstantValueMap(n, opset_version);
2195     if (n->kind() != prim::ListConstruct) {
2196       for (auto input : n->inputs()) {
2197         if (input->node()->kind() == prim::ListConstruct) {
2198           UpdateReliable(input, AreInputsReliableOrStatic(input->node()));
2199         }
2200       }
2201     }
2202   }
2203   UpdateReliable(n);
2204 
2205   // A node type that has no ComputeConstant logic may still have a
2206   // reliable shape that is not yet in ConstantValueMap, so we need this
2207   // logic to update ConstantValueMap.
2208   for (auto node_output : n->outputs()) {
2209     UpdateShapeConstantIfReliable(node_output);
2210   }
2211 
2212   GRAPH_DEBUG(
2213       "Torch graph after shape inference:", n->owningGraph()->toString());
2214 }
2215 
2216 void ONNXSetDynamicInputShape(
2217     std::shared_ptr<Graph>& graph,
2218     const std::unordered_map<
2219         std::string,
2220         std::unordered_map<int64_t, std::string>>& dynamic_axes,
2221     const std::vector<std::string>& input_names) {
2222   GRAPH_UPDATE("ONNX set dynamic input shape.");
2223   GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
2224     std::vector<std::string> res(dynamic_axes.size());
2225     std::transform(
2226         dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
2227           return pair.first;
2228         });
2229     return res;
2230   }());
2231 
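  // Illustrative example (names are hypothetical): with
  // dynamic_axes = {"input": {0: "batch"}}, dim 0 of graph input "input" is
  // replaced by a fresh ShapeSymbol; all axes sharing the name "batch" map
  // to the same symbol, so equal names imply equal symbolic sizes.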
2232   std::map<std::string, ::c10::ShapeSymbol> name_to_sym;
2233 
2234   for (const auto i : c10::irange(input_names.size())) {
2235     const auto& input_name = input_names[i];
2236     if (dynamic_axes.find(input_name) != dynamic_axes.end()) {
2237       auto axes_names = dynamic_axes.find(input_name)->second;
2238       TORCH_INTERNAL_ASSERT(i < graph->inputs().size());
2239       auto input_tensor_type = graph->inputs()[i]->type()->cast<TensorType>();
2240       if (!input_tensor_type) {
2241         continue;
2242       }
2243 
2244       auto shape_ref = input_tensor_type->symbolic_sizes().sizes();
2245       TORCH_CHECK(
2246           shape_ref.has_value(), "Input tensor shape should have value.");
2247       auto shape = shape_ref.value();
2248 
2249       for (const auto& pair : axes_names) {
2250         const auto axis = pair.first;
2251         const auto name = pair.second;
2252         if (name_to_sym.find(name) == name_to_sym.end()) {
2253           name_to_sym[name] = ::c10::ShapeSymbol::newSymbol();
2254         }
2255         TORCH_CHECK(
2256             axis < static_cast<int64_t>(shape.size()),
2257             "Dynamic shape axis should be no more than the shape dimension for ",
2258             name);
2259         shape[axis] = name_to_sym[name];
2260       }
2261 
2262       graph->inputs()[i]->setType(
2263           input_tensor_type->withSymbolicShapes(::c10::SymbolicShape(shape)));
2264     }
2265   }
2266 }
2267 
2268 bool HasSequenceTypeOutput(Node* node) {
2269   if (node->kind() == ::c10::onnx::SplitToSequence ||
2270       node->kind() == ::c10::onnx::SequenceInsert ||
2271       node->kind() == ::c10::onnx::SequenceEmpty ||
2272       node->kind() == ::c10::onnx::SequenceErase ||
2273       node->kind() == ::c10::onnx::SequenceConstruct ||
2274       node->kind() == ::c10::onnx::Loop || node->kind() == ::c10::onnx::If)
2275     return true;
2276   return false;
2277 }
2278 
2279 void ONNXUpdateTypeFromTensor(
2280     Value* graph_output,
2281     const at::Tensor& output,
2282     bool onnx_shape_inference) {
2283   if (onnx_shape_inference) {
2284     MergeInferredTypeAndSetMap(
2285         graph_output, TensorType::create(output), graph_output->type());
2286   } else {
2287     graph_output->inferTypeFrom(output);
2288   }
2289 }
2290 
2291 // Recursively look into elements in `output_obj`, and assign shape/type info
2292 // into flattened graph outputs. `outputs_index` is passed in to point to the
2293 // current index in flattened graph outputs. The updated `outputs_index` is
2294 // returned at the end of the function.
2295 size_t ONNXAssignOutputShape(
2296     std::shared_ptr<Graph>& graph,
2297     size_t outputs_index,
2298     PyObject* output_obj,
2299     bool onnx_shape_inference,
2300     bool is_script,
2301     int opset_version) {
2302   auto index_check = [&]() {
2303     TORCH_INTERNAL_ASSERT(
2304         outputs_index <= graph->outputs().size(),
2305         "Incorrect number of elements provided as example outputs.");
2306   };
2307 
2308   index_check();
2309 
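  // Illustrative example: an output object ((t1, t2), [t3]) is visited in
  // flattened order t1, t2, t3, with each tensor updating the graph output
  // at the running outputs_index.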
2310   if (THPVariable_Check(output_obj)) {
2311     const at::Tensor& var = THPVariable_Unpack(output_obj);
2312     ONNXUpdateTypeFromTensor(
2313         graph->outputs().at(outputs_index), var, onnx_shape_inference);
2314     outputs_index++;
2315   } else if (PyTuple_Check(output_obj)) {
2316     size_t tuple_len = PyTuple_GET_SIZE(output_obj);
2317     for (const auto i : c10::irange(tuple_len)) {
2318       outputs_index = ONNXAssignOutputShape(
2319           graph,
2320           outputs_index,
2321           PyTuple_GET_ITEM(output_obj, i),
2322           onnx_shape_inference,
2323           is_script,
2324           opset_version);
2325     }
2326   } else if (PyList_Check(output_obj)) {
2327     const auto list_len = PyList_GET_SIZE(output_obj);
2328     if (HasSequenceTypeOutput(graph->outputs().at(outputs_index)->node())) {
2329       auto output_type = graph->outputs().at(outputs_index)->type();
2330       TORCH_CHECK(
2331           output_type->cast<ListType>(),
2332           "Expected a sequence type, but received a non-iterable type in graph output index ",
2333           outputs_index);
2334       if (list_len > 0) {
2335         auto list_elem = PyList_GET_ITEM(output_obj, 0);
2336         TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
2337         auto& var = THPVariable_Unpack(list_elem);
2338         for (const auto i : c10::irange(1, list_len)) {
2339           list_elem = PyList_GET_ITEM(output_obj, i);
2340           TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
2341           auto& new_var = THPVariable_Unpack(list_elem);
2342           TORCH_CHECK(
2343               var.scalar_type() == new_var.scalar_type(),
2344               "Unsupported sequence with mixed element types in model outputs. "
2345               "ONNX supports only sequences of elements of the same data type.");
2346         }
2347         auto elem_type = graph->outputs()
2348                              .at(outputs_index)
2349                              ->type()
2350                              ->castRaw<ListType>()
2351                              ->getElementType()
2352                              ->cast<TensorType>();
2353         elem_type = elem_type->withScalarType(var.scalar_type());
2354         auto graph_output = graph->outputs().at(outputs_index);
2355         MergeInferredTypeAndSetMap(
2356             graph_output, graph_output->type(), ListType::create(elem_type));
2357       } else {
2358         graph->outputs()
2359             .at(outputs_index)
2360             ->setType(graph->outputs().at(outputs_index)->type());
2361       }
2362       outputs_index++;
2363     } else {
2364       // The torch output is a list type, but the ONNX node is not a
2365       // sequence type, e.g. prim::ListConstruct.
2366       for (const auto i : c10::irange(list_len)) {
2367         outputs_index = ONNXAssignOutputShape(
2368             graph,
2369             outputs_index,
2370             PyList_GET_ITEM(output_obj, i),
2371             onnx_shape_inference,
2372             is_script,
2373             opset_version);
2374       }
2375     }
2376   } else if (PyDict_Check(output_obj)) {
2377     // Support for dict data type is limited to fixed size dictionaries in
2378     // ONNX.
2379     // Dictionary values are unrolled and keys are not preserved.
2380     auto* items = PyDict_Items(output_obj);
2381     auto unrolled_dict = py::reinterpret_borrow<py::list>(items);
2382     TORCH_INTERNAL_ASSERT(PyList_Check(unrolled_dict.ptr()));
2383     for (const auto i : c10::irange(unrolled_dict.size())) {
2384       outputs_index = ONNXAssignOutputShape(
2385           graph,
2386           outputs_index,
2387           PyList_GET_ITEM(unrolled_dict.ptr(), i),
2388           onnx_shape_inference,
2389           is_script,
2390           opset_version);
2391     }
2392     Py_DECREF(items);
2393   } else if (THPUtils_checkString(output_obj)) {
2394     // Ignore string, since they are not supported as output in ONNX.
2395   } else if (PyNone_Check(output_obj)) {
2396     // Tracing:
2397     //    Ignore None, since it is not captured in IR graph as output.
2398     // Scripting:
2399     //    Ignore None, if observing a fixed `None` node in IR graph. Because
2400     //    it is meaningless to include it as graph output as it carries no
2401     //    data/information. Plus that static `None` is not supported in ONNX
2402     //    IR. Otherwise, the output should have type `Optional`, and should be
2403     //    converted to ONNX `Optional`.
2404 
2405     // More context:
2406     // Cause: in tracing we flatten the outputs in ONNXTracedModule.forward
2407     // in torch/jit/_trace.py, so the traced IR graph has its None outputs
2408     // omitted.
2409     // But then the outputs passed in here are un-flattened, which means they
2410     // contain None objects. Ideally we'd remove this difference.
2411     if (is_script && outputs_index < graph->outputs().size()) {
2412       if (graph->outputs().at(outputs_index)->node()->mustBeNone()) {
2413         if (opset_version >= 15) {
2414           ReplaceGraphOutputNoneWithOptional(graph, outputs_index);
2415           outputs_index++;
2416         } else {
2417           graph->eraseOutput(outputs_index);
2418         }
2419       } else {
2420         outputs_index++;
2421       }
2422     }
2423   } else {
2424     std::string msg =
2425         ("Model output has unsupported type. See "
2426          "https://pytorch.org/docs/stable/onnx.html#types. Got type: ");
2427     msg += THPUtils_typename(output_obj);
2428     throw std::runtime_error(msg);
2429   }
2430 
2431   index_check();
2432 
2433   return outputs_index;
2434 }
2435 
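// Create an onnx::Optional node with a float tensor element type, used as a
// stand-in for a `None` graph output, e.g.:
//   %opt : Optional(Float) = onnx::Optional()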
2436 Node* ONNXOptionalNodeForNone(std::shared_ptr<Graph>& graph) {
2437   TypePtr elem_type = TensorType::get()->withScalarType(at::ScalarType::Float);
2438   Node* opt_node = graph->create(::c10::onnx::Optional, 1);
2439   opt_node->ty_(Symbol::attr("type"), elem_type);
2440   opt_node->output()->setType(OptionalType::create(elem_type));
2441   return opt_node;
2442 }
2443 
2444 void ReplaceGraphOutputNoneWithOptional(
2445     std::shared_ptr<Graph>& graph,
2446     size_t outputs_index) {
2447   Node* opt_node = ONNXOptionalNodeForNone(graph);
2448   opt_node->insertBefore(graph->return_node());
2449   Value* graph_output = graph->outputs().at(outputs_index);
2450   // replace only the last value as Optional type only affects
2451   // the value right before output
2452   graph_output->replaceAllUsesAfterNodeWith(opt_node, opt_node->output());
2453   if (!graph_output->type()->cast<NoneType>()) {
2454     opt_node->addInput(graph_output);
2455     opt_node->copyMetadata(graph_output->node());
2456   }
2457 }
2458 
2459 void ONNXAssignOutputShape(
2460     std::shared_ptr<Graph>& graph,
2461     at::ArrayRef<at::Tensor> outputs,
2462     const python::IODescriptor& desc,
2463     bool onnx_shape_inference,
2464     bool is_script,
2465     int opset_version) {
2466   size_t outputs_index = 0;
2467   PyObject* py_obj = unflatten(outputs, desc);
2468   TORCH_INTERNAL_ASSERT(PyTuple_Check(py_obj));
2469 
2470   outputs_index = ONNXAssignOutputShape(
2471       graph,
2472       outputs_index,
2473       py_obj,
2474       onnx_shape_inference,
2475       is_script,
2476       opset_version);
2477 
2478   TORCH_INTERNAL_ASSERT(
2479       outputs_index == graph->outputs().size(),
2480       "Incorrect number of elements provided as example outputs.");
2481 
2482   Py_DECREF(py_obj);
2483   GRAPH_DUMP("After ONNXAssignOutputShape", graph);
2484 }
2485 
2486 void ONNXShapeTypeInference(
2487     std::shared_ptr<Graph>& graph,
2488     const ParamMap& params_dict,
2489     int opset_version) {
2490   ConstantValueMap::ClearMaps();
2491   SetGraphInputTypeReliable(graph.get());
2492   ONNXShapeTypeInference(graph->block(), params_dict, opset_version);
2493   ConstantValueMap::ClearMaps();
2494 }
2495 
2496 void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) {
2497   if (ConstantValueMap::HasTypeReliable(node_output->debugName())) {
2498     auto reliable = ConstantValueMap::GetTypeReliable(node_output->debugName())
2499                         .value_or(false);
2500     if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) {
2501       // TODO: ListType case
2502       if (auto output_tensor_type = node_output->type()->cast<TensorType>()) {
2503         if (output_tensor_type->dim()) {
2504           auto symbolic_sizes = output_tensor_type->symbolic_sizes();
2505           UpdateShapeConstantValueMap(node_output, symbolic_sizes);
2506         }
2507       }
2508     }
2509   }
2510 }
2511 
2512 } // namespace torch::jit
2513