1 #include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/onnx/constant_fold.h>
6 #include <torch/csrc/jit/passes/onnx/constant_map.h>
7 #include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
8 #include <torch/csrc/jit/passes/onnx/helper.h>
9 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
10 #include <torch/csrc/jit/python/python_arg_flatten.h>
11 #include <torch/csrc/jit/serialization/export.h>
12 #include <torch/csrc/jit/serialization/onnx.h>
13 #include <torch/csrc/utils/python_strings.h>
14
15 #include <torch/csrc/onnx/diagnostics/diagnostics.h>
16
17 #include <onnx/shape_inference/implementation.h>
18 #include <algorithm>
19 #include <cmath>
20 #include <iterator>
21 #include <limits>
22 #include <unordered_set>
23 #include <utility>
24
25 namespace torch::jit {
26
inline bool PyNone_Check(PyObject* o) {
28 return o == Py_None;
29 }
30
std::pair<TypePtr, bool> MergeInferredType(
32 const TypePtr& existing_type,
33 const TypePtr& inferred_type) {
34 auto new_list_type = inferred_type->cast<ListType>();
35 auto use_inferred_type = false;
36 if (new_list_type) {
37 return std::make_pair(inferred_type, true);
38 }
39 auto new_tensor_type = inferred_type->cast<TensorType>();
40 auto old_tensor_type = existing_type->cast<TensorType>();
41
42 if (new_tensor_type && old_tensor_type) {
43 if (!old_tensor_type->device()) {
      // Device not available means this is an invalid tensor type (most
      // likely an empty one); return the inferred type directly.
46 return std::make_pair(new_tensor_type, true);
47 }
48 auto type = old_tensor_type;
49 if (new_tensor_type->dim()) {
50 type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes());
51 use_inferred_type = true;
52 }
53 if (new_tensor_type->scalarType().has_value()) {
54 type = type->withScalarType(new_tensor_type->scalarType());
55 use_inferred_type = true;
56 }
57 return std::make_pair(type, use_inferred_type);
58 }
59
60 if (old_tensor_type) {
61 return std::make_pair(existing_type, false);
62 }
63
64 auto old_list_type = existing_type->cast<ListType>();
65 if (new_tensor_type && old_list_type) {
66 if (new_tensor_type->sizes().isComplete()) {
67 return std::make_pair(inferred_type, true);
68 }
69 return std::make_pair(existing_type, false);
70 }
71
72 return std::make_pair(inferred_type, true);
73 }
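// Illustrative example of the merge above: merging an existing tensor type
// Float(*, *) (device known, fully dynamic shape) with an inferred type
// Float(2, 3) yields Float(2, 3), and the returned flag reports that the
// inferred information was used.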
74
void MergeInferredTypeAndSetMap(
76 Value* dest_v,
77 const TypePtr& existing_type,
78 const TypePtr& inferred_type) {
79 auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type);
80 dest_v->setType(mergedType);
81 ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
82 }
83
84 namespace {
85 namespace onnx_torch = ::torch::onnx;
86 namespace onnx = ::ONNX_NAMESPACE;
87 namespace diagnostics = ::torch::onnx::diagnostics;
88
// SymbolDimMap is a Torch-to-ONNX shape look-up. It is built so it can be
// returned by the export function. During the export, however, when we come
// across new ONNX shapes, the reverse look-up is needed. To avoid incurring
// a linear-time look-up, we maintain DimSymbolMap in parallel.
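// For example, when two ONNX dims carry the same dim_param (say "batch"),
// the look-up below reuses the ShapeSymbol created for the first occurrence
// instead of scanning SymbolDimMap linearly.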
c10::ShapeSymbol ONNXDimToShapeSymbol(
94 const onnx::TensorShapeProto_Dimension& dim,
95 SymbolDimMap& symbol_dim_map,
96 DimSymbolMap& dim_symbol_map) {
97 if (dim.has_dim_value()) {
98 return c10::ShapeSymbol::fromStaticSize(dim.dim_value());
99 }
100 std::optional<c10::ShapeSymbol> sym = std::nullopt;
101 if (dim.has_dim_param()) {
102 // If this param is already known, assign the same Symbol.
103 GRAPH_UPDATE("Got dim_param:", dim.dim_param());
104 auto maybe_symbol = dim_symbol_map.find(dim.dim_param());
105 if (maybe_symbol != dim_symbol_map.end()) {
106 sym = maybe_symbol->second;
107 }
108 }
109 if (!sym) {
110 sym = c10::ShapeSymbol::newSymbol();
111 // If dim.dim_param() is empty, no need to keep track
112 // because there won't be duplicates.
113 symbol_dim_map[sym.value()] = dim.dim_param();
114 dim_symbol_map[dim.dim_param()] = sym.value();
115 }
116 return sym.value();
117 }
118
TensorTypePtr TorchTensorTypeFromONNX(
120 const onnx::TypeProto_Tensor& onnx_tensor_type,
121 SymbolDimMap& symbol_dim_map,
122 DimSymbolMap& dim_symbol_map) {
123 std::optional<at::ScalarType> scalar_type;
124 if (onnx_tensor_type.has_elem_type()) {
125 scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type());
126 }
127
128 auto v_type = TensorType::create(
129 scalar_type,
130 at::kCPU,
131 c10::SymbolicShape(),
132 c10::VaryingShape<c10::Stride>{},
133 {});
134 if (onnx_tensor_type.has_shape()) {
135 std::vector<c10::ShapeSymbol> sizes;
136 const auto& onnx_shape = onnx_tensor_type.shape();
137
138 for (const auto i : c10::irange(onnx_shape.dim_size())) {
139 sizes.emplace_back(ONNXDimToShapeSymbol(
140 onnx_shape.dim(i), symbol_dim_map, dim_symbol_map));
141 }
142 v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
143 v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));
144
145 if (v_type->sizes().concrete_sizes().has_value()) {
146 // Populate strides based on sizes info, if sizes are all static.
147 // Creating strides ensures yielding True for isCompleteTensor.
148 v_type = v_type->contiguous();
149 }
150 }
151
152 return v_type;
153 }
154
ListTypePtr TorchListTypeFromONNX(
156 const onnx::TypeProto_Sequence& onnx_sequence_type,
157 SymbolDimMap& symbol_dim_map,
158 DimSymbolMap& dim_symbol_map) {
159 if (onnx_sequence_type.has_elem_type()) {
160 const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type();
161 if (onnx_seq_elem_type.has_tensor_type()) {
162 const auto& onnx_tensor_type = onnx_seq_elem_type.tensor_type();
163 const auto v_tensor_type = TorchTensorTypeFromONNX(
164 onnx_tensor_type, symbol_dim_map, dim_symbol_map);
165 auto v_type = ListType::create(v_tensor_type);
166 return v_type;
167 }
168 }
169 return nullptr;
170 }
171
void UpdateTorchValueByOnnxValueInfo(
173 Value* v,
174 const onnx::ValueInfoProto& p_info,
175 SymbolDimMap& symbol_dim_map,
176 DimSymbolMap& dim_symbol_map) {
177 if (!p_info.has_type()) {
178 return;
179 }
180
181 const auto& p_type = p_info.type();
182 if (p_type.has_tensor_type()) {
183 const auto torch_tensor_type = TorchTensorTypeFromONNX(
184 p_type.tensor_type(), symbol_dim_map, dim_symbol_map);
185 if (torch_tensor_type) {
186 MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type);
187 }
188 } else if (p_type.has_sequence_type()) {
189 const auto torch_list_type = TorchListTypeFromONNX(
190 p_type.sequence_type(), symbol_dim_map, dim_symbol_map);
191 if (torch_list_type) {
192 MergeInferredTypeAndSetMap(v, v->type(), torch_list_type);
193 }
194 }
195 }
196
bool IsValidONNXControlflowNode(const Node* n) {
  // Skip when the block count is zero. This happens while the node is being
  // created and doesn't have its subblocks attached yet. Shape inference for
  // these nodes runs later, once the subgraphs have completed shape inference.
201 auto node_kind = n->kind();
202 if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
203 if (n->blocks().empty()) {
204 return false;
205 }
206 }
207
208 return true;
209 }
210
bool IsValidONNXNode(const Node* n) {
212 auto node_kind = n->kind();
213
214 if (!node_kind.is_onnx()) {
215 // node kind is not ONNX, skipped.
216 return false;
217 }
218
219 if (!IsValidONNXControlflowNode(n)) {
220 return false;
221 }
222
223 for (auto b : n->blocks()) {
224 for (auto b_n : b->nodes()) {
225 if (!IsValidONNXNode(b_n)) {
226 return false;
227 }
228 }
229 }
230
231 return true;
232 }
233
bool CustomSettype(Node* node) {
  // Helper to decide whether a non-ONNX node actually carries a custom
  // setType from the user.
  // Go through every symbolic size: if any of them is static, we assume the
  // type was set by the user. If all of them are * (dynamic), we treat the
  // node as having no user-given type, since unreliable nodes have * shapes
  // anyway.
241 auto all_output_has_type = [](Value* output) {
242 if (auto output_type = output->type()->cast<TensorType>()) {
243 if (auto sizes = output_type->symbolic_sizes().sizes()) {
244 return std::any_of(std::begin(*sizes), std::end(*sizes), [](auto size) {
245 return size.is_static();
246 });
247 }
248 }
249 return false;
250 };
251
252 return std::all_of(
253 node->outputs().begin(), node->outputs().end(), all_output_has_type);
254 }
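// Illustrative example: an output typed as Float(2, *, *) counts as a custom
// setType (it has one static dim), whereas Float(*, *, *) does not.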
255
Value* CloneValueFromListConstruct(
257 Value* v,
258 const std::shared_ptr<Graph>& n_graph,
259 int opset_version) {
260 auto lc_node = v->node();
261 TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct);
  // In jit/passes/onnx/peephole.cpp::eraseListConstruct,
  // prim::ListConstruct is converted to onnx::Concat. The conversion should
  // eventually be moved to symbolic. For now, treat this operator as a
  // special case and change it from list type to tensor type. The scalar
  // type is preserved. If the element type is Int, insert an onnx::Concat
  // node into the graph.
268 TypePtr elem = v->type()->castRaw<ListType>()->getElementType();
269 std::optional<at::ScalarType> scalar_type = std::nullopt;
270 if (elem->cast<IntType>()) {
271 scalar_type = at::kLong;
272 if (isValidToTransformToONNXConcatNode(v->node())) {
273 auto concat_node = transformToONNXConcatNode(
274 n_graph.get(), v->node(), true, opset_version);
275 return concat_node->output();
276 }
277 } else if (elem->cast<FloatType>()) {
278 scalar_type = at::kFloat;
279 } else if (elem->cast<BoolType>()) {
280 scalar_type = at::kBool;
281 } else if (auto t_type = elem->cast<TensorType>()) {
282 scalar_type = t_type->scalarType();
283 }
284
285 auto input = n_graph->addInput();
286 if (scalar_type) {
287 auto v_type = TensorType::create(
288 scalar_type.value(),
289 at::kCPU,
290 c10::SymbolicShape(),
291 c10::VaryingShape<c10::Stride>{},
292 {});
293 input->setType(v_type);
294 }
295 return input;
296 }
297
298 // Clone the node n for the new graph.
Node* CloneNodeToGraph(
300 Node* n,
301 std::shared_ptr<Graph> n_graph,
302 const ParamMap& params_dict,
303 int opset_version) {
304 auto clone_node = n_graph->createClone(
      n, [&n_graph, &params_dict, opset_version](Value* v) {
306 auto v_n = v->node();
307 switch (v_n->kind()) {
308 case ::c10::prim::Constant:
309 case ::c10::onnx::Constant: {
310 // Clone the input if it is constant.
311 auto constant_n = n_graph->insertNode(
312 n_graph->createClone(v_n, [](Value* v) { return v; }));
313 return constant_n->output();
314 }
315 case ::c10::prim::ListConstruct: {
316 return CloneValueFromListConstruct(v, n_graph, opset_version);
317 }
318 case ::c10::prim::PackPadded: {
319 auto input = n_graph->addInput();
320 if (v == v_n->output(0)) {
321 // Only the first output requires this workaround.
322 // In `peephole` pass, user nodes are modified to consume the
323 // input instead.
324 input->copyMetadata(v_n->input(0));
325 } else {
326 input->copyMetadata(v);
327 }
328 return input;
329 }
330 default: {
331 // Try to lookup input value and insert it into the graph.
332 // If the input value is unknown, set it to graph input in the new
333 // graph, and copy over metadata, such as datatype and shape.
334 ::std::optional<at::Tensor> val = ::std::nullopt;
335 auto v0 = params_dict.find(v->debugName());
336 if (v0 != params_dict.end()) {
337 val = v0->second.toTensor();
338 } else {
339 val = ConstantValueMap::GetValue(v->debugName());
340 }
341
342 if (val.has_value()) {
343 return n_graph
344 ->insertNode(n_graph->create(::c10::onnx::Constant)
345 ->t_(attr::value, val.value()))
346 ->output();
347 }
348 auto input = n_graph->addInput();
349 input->copyMetadata(v);
350 return input;
351 }
352 }
353 });
354 return clone_node;
355 }
356
bool HasValidType(const TypePtr& type, const std::string& name) {
358 if (auto t_type = type->cast<TensorType>()) {
359 if (!t_type->scalarType().has_value()) {
360 GRAPH_UPDATE("Input ", name, " is missing tensor datatype.");
361 return false;
362 }
363 } else if (auto s_type = type->cast<ListType>()) {
364 auto e_type = s_type->getElementType();
365 return HasValidType(e_type, name);
366 } else if (auto o_type = type->cast<OptionalType>()) {
367 auto e_type = o_type->getElementType();
368 return HasValidType(e_type, name);
369 }
370 return true;
371 }
372
bool IsGraphValidForInference(const std::shared_ptr<Graph>& graph) {
  // Verify that every input has a type (either Tensor, Sequence or Optional)
  // and a scalar type. This is a requirement for ONNX graph inputs.
  for (auto in : graph->inputs()) {
    if (!HasValidType(in->type(), in->debugName())) {
      return false;
    }
  }
  return true;
}
381
void ConvertGraphToONNXProto(
383 const std::shared_ptr<Graph>& graph,
384 std::shared_ptr<onnx::ModelProto>& model_proto,
385 SymbolDimMap& symbol_dim_map,
386 DimSymbolMap& dim_symbol_map,
387 int opset_version) {
388 auto
389 [model_proto_tmp,
390 export_map,
391 new_symbol_dim_map,
392 val_use_external_data_format,
393 node_names] =
394 export_onnx(
395 graph,
396 {},
397 opset_version,
398 {},
399 false,
400 onnx_torch::OperatorExportTypes::ONNX,
401 true,
402 true,
403 {},
404 true,
405 false,
406 std::string());
407 model_proto = std::move(model_proto_tmp);
408 symbol_dim_map.insert(new_symbol_dim_map.begin(), new_symbol_dim_map.end());
409 for (const auto& pair : new_symbol_dim_map) {
410 dim_symbol_map[pair.second] = pair.first;
411 }
412 for (int i = 0; i < model_proto->graph().output_size(); ++i) {
413 model_proto->mutable_graph()->mutable_output(i)->clear_type();
414 }
415 }
416
std::optional<at::Tensor> ComputeConstantFolding(Node* n, int opset_version) {
418 if (n->inputs().empty()) {
419 return std::nullopt;
420 }
421 std::vector<at::Tensor> inputTensorValues;
422 for (auto i : c10::irange(n->inputs().size())) {
423 if (TensorTypePtr input_type = n->input(i)->type()->cast<TensorType>()) {
424 if (!ConstantValueMap::HasValue(n->input(i)->debugName())) {
425 return std::nullopt;
426 }
427 auto tensor_value =
428 ConstantValueMap::GetValue(n->input(i)->debugName()).value();
429 inputTensorValues.emplace_back(tensor_value);
430 }
431 }
432 if (inputTensorValues.size() < n->inputs().size()) {
433 return std::nullopt;
434 }
435 try {
436 return onnx_constant_fold::runTorchBackendForOnnx(
437 n, inputTensorValues, opset_version);
438 } catch (const std::exception& ex) {
439 auto ex_str = std::string(ex.what());
440 ex_str = ex_str.substr(0, ex_str.find('\n'));
441 TORCH_WARN("Constant folding in symbolic shape inference fails: ", ex_str);
442 return std::nullopt;
443 }
444 }
445
// Compute the output shape of Reshape from the symbolic input shape and the
// (possibly symbolic) target shape.
std::optional<::c10::SymbolicShape> ComputeShapeFromReshape(
448 Node* n,
449 const c10::SymbolicShape& input_shape,
450 const c10::SymbolicShape& shape,
451 int opset_version) {
452 std::vector<c10::ShapeSymbol> input_shape_vector =
453 input_shape.sizes().value();
454 std::vector<c10::ShapeSymbol> shape_vector = shape.sizes().value();
455 TORCH_INTERNAL_ASSERT(
456 !input_shape_vector.empty() || !shape_vector.empty(),
457 "Reshape node should have at least one input size > 0 when constant folding.");
458 if (shape_vector.empty()) {
459 return input_shape;
460 }
461 if (input_shape_vector.empty()) {
462 return shape;
463 }
464
465 auto is_zero = [](c10::ShapeSymbol& ss) { return ss.value() == 0; };
466 auto it_0 = std::find_if(shape_vector.begin(), shape_vector.end(), is_zero);
467 bool shape_has_zero = it_0 != shape_vector.end();
468
469 int minus_one_pos = -1;
470 for (auto i : c10::irange(shape_vector.size())) {
471 if (shape_vector[i].value() == -1) {
472 minus_one_pos = i;
473 break;
474 }
475 }
476
477 int allowzero = 0;
478 if (opset_version >= 14 && n->hasAttributeS("allowzero")) {
479 allowzero = n->i(attr::allowzero);
480 }
481
482 TORCH_CHECK(
483 !(shape_has_zero && allowzero == 1 && minus_one_pos != -1),
484 "0 and -1 cannot both be present in `Shape` input of `Reshape` node, when `allowzero=1`.");
485
486 if (minus_one_pos == -1 && (!shape_has_zero || allowzero)) {
487 return shape;
488 }
489 std::vector<c10::ShapeSymbol> final_shape;
490 uint64_t shape_ratio = 1;
491 std::unordered_map<int64_t, int64_t> sym_map;
492 for (const c10::ShapeSymbol& input_shape : input_shape_vector) {
493 // input_shape.static_size() could be zero when torch.tensor([]) is used.
494 if (input_shape.is_static() && input_shape.static_size() != 0) {
495 if (shape_ratio >=
496 std::numeric_limits<uint64_t>::max() / input_shape.static_size()) {
497 TORCH_WARN(
498 "ComputeShapeFromReshape(), shape_ratio overflows, skip shape inference.");
499 return std::nullopt;
500 } else {
501 shape_ratio *= static_cast<uint64_t>(input_shape.static_size());
502 }
503 } else {
504 auto value = input_shape.value();
505 sym_map.emplace(value, 0).first->second += 1;
506 }
507 }
508 int shape_size = static_cast<int>(shape_vector.size());
509 for (const int i : c10::irange(shape_size)) {
510 if (i == minus_one_pos) {
511 continue;
512 }
513 c10::ShapeSymbol& target_shape = shape_vector[i];
514 if (target_shape.value() == 0) {
515 target_shape = input_shape_vector[i];
516 }
517 if (target_shape.is_static()) {
518 shape_ratio /= static_cast<uint64_t>(target_shape.static_size());
519 } else {
520 auto value = target_shape.value();
521 if (sym_map.find(value) == sym_map.end()) {
522 return std::nullopt;
523 }
524 sym_map[value]--;
525 if (sym_map[value] == 0) {
526 sym_map.erase(value);
527 }
528 }
529 }
530
531 // sym_map is used to match shape symbols between the input and shape.
532 // If there is a mismatch, the output shape cannot be estimated.
533 if (!sym_map.empty()) {
534 return std::nullopt;
535 }
536
537 TORCH_INTERNAL_ASSERT(
538 minus_one_pos != -1,
539 "There are no examples for shape_has_zero = true && minus_one_pos == -1.");
540
541 for (const auto i : c10::irange(minus_one_pos)) {
542 c10::ShapeSymbol cur_shape(
543 shape_vector[i].value() == 0 ? input_shape_vector[i] : shape_vector[i]);
544 final_shape.push_back(cur_shape);
545 }
546 if (minus_one_pos != -1) {
547 final_shape.push_back(
548 c10::ShapeSymbol::fromStaticSize(static_cast<int64_t>(shape_ratio)));
549 }
550 for (auto i = minus_one_pos + 1; i < shape_size; i++) {
551 c10::ShapeSymbol cur_shape(
552 shape_vector[i].value() == 0 ? input_shape_vector[i] : shape_vector[i]);
553 final_shape.push_back(cur_shape);
554 }
555 c10::SymbolicShape final_shape_0(final_shape);
556 return final_shape_0;
557 }
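// Illustrative example: reshaping an input of shape (2, 3, 4) with
// shape = (0, -1) and allowzero = 0 copies dim 0 from the input and infers
// the -1 dim from the remaining element count, producing (2, 12).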
558
std::optional<::c10::SymbolicShape> ComputeShapeFromExpand(
560 const std::vector<::c10::ShapeSymbol>& input_shape,
561 const std::vector<int64_t>& reshape) {
562 for (const auto& it : reshape) {
563 if (it < 0) {
564 return std::nullopt;
565 }
566 }
567 std::vector<::c10::ShapeSymbol> final_shape;
568 if (input_shape.size() >= reshape.size()) {
569 final_shape = input_shape;
570 } else {
571 for (auto v : reshape) {
572 final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(v));
573 }
574 }
575 auto min_size = std::min(input_shape.size(), reshape.size());
576 for (const auto i : c10::irange(min_size)) {
577 auto idx = final_shape.size() - i - 1;
578 auto input_shape_idx = input_shape.size() - i - 1;
579 auto reshape_idx = reshape.size() - i - 1;
580 if (input_shape[input_shape_idx].is_static()) {
581 auto input_shape_value = input_shape[input_shape_idx].static_size();
582 auto reshape_value = reshape[reshape_idx];
583 TORCH_INTERNAL_ASSERT(
584 input_shape_value == reshape_value || input_shape_value == 1 ||
585 reshape_value == 1,
586 "ONNX Expand input shape constraint not satisfied.");
587 final_shape[idx] = ::c10::ShapeSymbol::fromStaticSize(
588 std::max(input_shape_value, reshape_value));
589
590 } else {
591 final_shape[idx] = ::c10::ShapeSymbol::newSymbol();
592 }
593 }
594 ::c10::SymbolicShape shape(final_shape);
595 return shape;
596 }
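// Illustrative example: expanding an input of shape (3, 1) with
// shape = (2, 3, 4) broadcasts to (2, 3, 4). Any negative entry in `reshape`
// makes the result unknown, so std::nullopt is returned.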
597
std::optional<::c10::SymbolicShape> ComputeShapeFromTile(
599 const std::vector<::c10::ShapeSymbol>& input_shape,
600 const std::vector<int64_t>& reshape) {
601 TORCH_INTERNAL_ASSERT(
602 input_shape.size() == reshape.size(),
603 "ONNX Tile input shapes do not match.");
604 for (const auto& it : reshape) {
605 if (it < 0) {
606 return std::nullopt;
607 }
608 }
609 std::vector<::c10::ShapeSymbol> final_shape;
610 final_shape.reserve(input_shape.size());
611 for (const auto i : c10::irange(input_shape.size())) {
612 if (input_shape[i].is_static()) {
613 final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(
614 input_shape[i].static_size() * reshape[i]));
615 } else {
616 final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
617 }
618 }
619 ::c10::SymbolicShape shape(final_shape);
620 return shape;
621 }
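// Illustrative example: tiling an input of shape (2, 3) with repeats (2, 2)
// yields (4, 6); a dynamic input dim stays dynamic (a fresh symbol).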
622
void UpdateRank(Value* value, size_t rank) {
624 ConstantValueMap::SetRank(value->debugName(), rank);
625 if (TensorTypePtr value_type = value->type()->cast<TensorType>()) {
626 std::optional<size_t> rank_opt = rank;
627 auto shape = ::c10::SymbolicShape(rank_opt);
628 value->setType(value_type->withSymbolicShapes(shape));
629 }
630 }
631
void UpdateShapeFromVector(
633 Value* value,
634 const std::vector<int64_t>& shape_size) {
635 ::c10::SymbolicShape shape(shape_size);
636 ConstantValueMap::SetShape(value->debugName(), shape);
637 if (shape_size.empty()) {
638 UpdateRank(value, 0);
639 return;
640 }
641 ConstantValueMap::SetRank(value->debugName(), shape_size.size());
642 if (TensorTypePtr value_type = value->type()->cast<TensorType>()) {
643 value->setType(value_type->withSymbolicShapes(shape));
644 }
645 }
646
void UpdateShape(Value* value, const ::c10::SymbolicShape& shape) {
648 ConstantValueMap::SetShape(value->debugName(), shape);
649 if (shape.rank().has_value()) {
650 auto rank = shape.rank().value();
651 if (rank == 0) {
652 UpdateRank(value, 0);
653 return;
654 }
655 ConstantValueMap::SetRank(value->debugName(), rank);
656 if (TensorTypePtr value_type = value->type()->cast<TensorType>()) {
657 value->setType(value_type->withSymbolicShapes(shape));
658 }
659 }
660 }
661
void UpdateShapeConstantValueMap(
663 const Value* value,
664 const ::c10::SymbolicShape& shape) {
665 ConstantValueMap::SetShape(value->debugName(), shape);
666 if (shape.rank().has_value()) {
667 auto rank = shape.rank().value();
668 ConstantValueMap::SetRank(value->debugName(), rank);
669 }
670 }
671
std::optional<std::vector<int64_t>> GetValueFromListConstructNode(
673 Node* lc_node) {
674 std::vector<int64_t> shape_size;
675 for (const auto& input : lc_node->inputs()) {
676 if (input->type()->cast<TensorType>() &&
677 ConstantValueMap::HasValue(input->debugName())) {
678 auto lc_value = ConstantValueMap::GetValue(input->debugName()).value();
679 if (lc_value.dim() == 0) {
680 int64_t lc_value_0 = lc_value.item<int64_t>();
681 shape_size.emplace_back(lc_value_0);
682 }
683 }
684 }
685 return lc_node->inputs().size() == shape_size.size()
686 ? std::optional<std::vector<int64_t>>(shape_size)
687 : std::nullopt;
688 }
689
void SetShapeValueFromListConstructNode(Node* lc_node) {
691 std::vector<c10::ShapeSymbol> shape_size;
692 for (const auto& input : lc_node->inputs()) {
693 if (TensorTypePtr shape_type = input->type()->cast<TensorType>()) {
694 if (ConstantValueMap::HasValue(input->debugName())) {
695 auto lc_value = ConstantValueMap::GetValue(input->debugName()).value();
696 if (lc_value.dim() == 0) {
697 int64_t lc_value_0 = lc_value.item<int64_t>();
698 shape_size.emplace_back(c10::ShapeSymbol::fromStaticSize(lc_value_0));
699 }
700 } else if (ConstantValueMap::HasShapeValue(input->debugName())) {
701 auto lc_value =
702 ConstantValueMap::GetShapeValue(input->debugName()).value();
703 if (lc_value.rank() == 1U) {
704 shape_size.emplace_back(lc_value.at(0));
705 }
706 }
707 }
708 }
709 if (lc_node->inputs().size() == shape_size.size()) {
710 c10::SymbolicShape final_shape(shape_size);
711 ConstantValueMap::SetShapeValue(
712 lc_node->output()->debugName(), final_shape);
713 }
714 }
715
std::vector<::c10::ShapeSymbol> Broadcast(
717 const std::vector<::c10::ShapeSymbol>& input_shape_value_0,
718 const std::vector<::c10::ShapeSymbol>& input_shape_value_1) {
719 size_t rank_0 = input_shape_value_0.size();
720 size_t rank_1 = input_shape_value_1.size();
721 size_t rank_max = std::max(rank_0, rank_1);
722 size_t rank_min = std::min(rank_0, rank_1);
723 std::vector<::c10::ShapeSymbol> final_shape;
724 final_shape.reserve(rank_max);
725 std::generate_n(
726 std::back_inserter(final_shape), rank_max, ::c10::ShapeSymbol::newSymbol);
727 for (auto idx : c10::irange(rank_min)) {
728 const c10::ShapeSymbol& ss_shape_0 = input_shape_value_0[rank_0 - 1 - idx];
729 const c10::ShapeSymbol& ss_shape_1 = input_shape_value_1[rank_1 - 1 - idx];
730 bool is_static_0 = ss_shape_0.is_static();
731 bool is_static_1 = ss_shape_1.is_static();
732 size_t shape_idx = rank_max - 1 - idx;
733 if (is_static_0 && is_static_1) {
734 int64_t static_0_sz = ss_shape_0.static_size();
735 int64_t static_1_sz = ss_shape_1.static_size();
      // Corner case: a static dim of size 0 (empty tensor). Broadcasting a
      // size-0 dim with any other dim yields a size-0 dim.
738 if (std::min(static_0_sz, static_1_sz) == 0) {
739 final_shape[shape_idx] = ::c10::ShapeSymbol::fromStaticSize(
740 std::min(static_0_sz, static_1_sz));
741 } else {
742 final_shape[shape_idx] = ::c10::ShapeSymbol::fromStaticSize(
743 std::max(static_0_sz, static_1_sz));
744 }
745 } else if (!is_static_0 && !is_static_1) {
746 if (ss_shape_0.value() == ss_shape_1.value()) {
747 final_shape[shape_idx] = ss_shape_0;
748 }
749 }
750 }
751 if (rank_0 < rank_1) {
752 for (size_t idx = rank_min; idx < rank_max; idx++) {
753 size_t shape_idx = rank_max - 1 - idx;
754 final_shape[shape_idx] = input_shape_value_1[shape_idx];
755 }
756 } else {
757 for (size_t idx = rank_min; idx < rank_max; idx++) {
758 size_t shape_idx = rank_max - 1 - idx;
759 final_shape[shape_idx] = input_shape_value_0[shape_idx];
760 }
761 }
762 return final_shape;
763 }
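// Illustrative example: broadcasting (2, 1, 4) with (3, 4) gives (2, 3, 4).
// A static dim of size 0 wins over the other size, and two dynamic dims only
// stay equal in the output when they carry the same symbol.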
764
void ProcessBroadcastNode(Node* n) {
766 TORCH_INTERNAL_ASSERT(n->inputs().size() == 2);
767 if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
768 ConstantValueMap::HasShape(n->input(1)->debugName())) {
769 auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
770 auto input_shape_value_0 = input_shape_0.value().sizes().value();
771 auto input_shape_1 = ConstantValueMap::GetShape(n->input(1)->debugName());
772 auto input_shape_value_1 = input_shape_1.value().sizes().value();
773 auto final_shape = Broadcast(input_shape_value_0, input_shape_value_1);
774 UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
775 }
776 }
777
void ProcessShapeForConcatNode(Node* n) {
779 int axis = n->i(attr::axis);
780 if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
781 auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
782 size_t axis_adjust = 0;
783 if (axis >= 0) {
784 axis_adjust = static_cast<size_t>(axis);
785 } else {
786 axis_adjust = static_cast<size_t>(axis + static_cast<int>(rank));
787 }
788 std::vector<::c10::ShapeSymbol> final_shape;
789 final_shape.reserve(rank);
790 for (auto idx : c10::irange(rank)) {
791 if (idx == axis_adjust) {
792 auto flag = true;
793 int64_t size_total = 0;
794 for (auto input_idx : c10::irange(n->inputs().size())) {
795 if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
796 auto input_shape =
797 ConstantValueMap::GetShape(n->input(input_idx)->debugName());
798 auto input_shape_value = input_shape.value().sizes();
799 auto shape_symbol = input_shape_value.value()[idx];
800 if (shape_symbol.is_static()) {
801 size_total += shape_symbol.static_size();
802 } else {
803 flag = false;
804 break;
805 }
806 }
807 }
808 if (flag) {
809 final_shape.emplace_back(
810 ::c10::ShapeSymbol::fromStaticSize(size_total));
811 } else {
812 final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
813 }
814 } else {
815 auto flag = false;
816 for (auto input_idx : c10::irange(n->inputs().size())) {
817 if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
818 auto input_shape =
819 ConstantValueMap::GetShape(n->input(input_idx)->debugName());
820 auto input_shape_value = input_shape.value().sizes();
821 auto shape_symbol = input_shape_value.value()[idx];
822 if (shape_symbol.is_static()) {
823 final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(
824 shape_symbol.static_size()));
825 flag = true;
826 break;
827 }
828 }
829 }
830 if (!flag) {
831 final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
832 }
833 }
834 }
835 UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
836 }
837 }
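// Illustrative example: Concat(axis=0) over inputs of shapes (2, 3) and
// (4, 3) produces (6, 3); if any input's size along `axis` is dynamic, the
// concatenated dim becomes a fresh symbol instead.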
838
void ProcessShapeValueForConcatNode(Node* n) {
840 auto rank = n->inputs().size();
841 std::vector<c10::ShapeSymbol> shape_size;
842 for (const auto& input : n->inputs()) {
843 if (ConstantValueMap::HasValue(input->debugName())) {
844 auto concat_value =
845 ConstantValueMap::GetValue(input->debugName()).value();
846 if (concat_value.dim() == 1 && concat_value.size(0) == 1) {
847 auto concat_value_0 = concat_value[0].item<int64_t>();
848 shape_size.emplace_back(
849 c10::ShapeSymbol::fromStaticSize(concat_value_0));
850 }
851 } else if (ConstantValueMap::HasShapeValue(input->debugName())) {
852 auto concat_value =
853 ConstantValueMap::GetShapeValue(input->debugName()).value();
854 if (concat_value.rank() == 1U) {
855 shape_size.emplace_back(concat_value.at(0));
856 }
857 }
858 }
859 if (rank == shape_size.size()) {
860 c10::SymbolicShape final_shape(shape_size);
861 ConstantValueMap::SetShapeValue(n->output(0)->debugName(), final_shape);
862 }
863 }
864
void ProcessConcatNode(Node* n) {
866 ProcessShapeForConcatNode(n);
867 ProcessShapeValueForConcatNode(n);
868 }
869
void ProcessMatMulNode(Node* n) {
871 if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
872 ConstantValueMap::HasShape(n->input(1)->debugName())) {
873 auto input_shape_0 =
874 ConstantValueMap::GetShape(n->input(0)->debugName()).value();
875 auto input_shape_value_0 = input_shape_0.sizes().value();
876 auto input_shape_1 =
877 ConstantValueMap::GetShape(n->input(1)->debugName()).value();
878 auto input_shape_value_1 = input_shape_1.sizes().value();
879 size_t rank_0 = input_shape_value_0.size();
880 size_t rank_1 = input_shape_value_1.size();
881 // Handle inputs of rank 1 just like numpy.matmul:
882 // https://numpy.org/doc/stable/reference/generated/numpy.matmul.html
883 auto is_rank_0_1 = false;
884 if (rank_0 == 1) {
885 input_shape_value_0.insert(
886 input_shape_value_0.begin(), ::c10::ShapeSymbol::fromStaticSize(1));
887 rank_0 = 2;
888 is_rank_0_1 = true;
889 }
890 auto is_rank_1_1 = false;
891 if (rank_1 == 1) {
892 input_shape_value_1.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
893 rank_1 = 2;
894 is_rank_1_1 = true;
895 }
    // Per https://pytorch.org/docs/stable/generated/torch.matmul.html,
    // the broadcasting logic only applies to the batch dimensions and not to
    // the matrix dimensions, so we remove the matrix dimensions (the last 2
    // dimensions) before broadcasting.
900 auto final_shape = Broadcast(
901 std::vector<::c10::ShapeSymbol>(
902 input_shape_value_0.begin(), input_shape_value_0.end() - 2),
903 std::vector<::c10::ShapeSymbol>(
904 input_shape_value_1.begin(), input_shape_value_1.end() - 2));
    // Add the last 2 dimensions back, unless they did not exist in the first
    // place and were inserted by this function. Then apply [n,k]x[k,m]=[n,m],
    // where n = input_shape_value_0[rank_0 - 2] and
    // m = input_shape_value_1[rank_1 - 1].
908 if (!is_rank_0_1) {
909 final_shape.emplace_back(input_shape_value_0[rank_0 - 2]);
910 }
911 if (!is_rank_1_1) {
912 final_shape.emplace_back(input_shape_value_1[rank_1 - 1]);
913 }
914 UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
915 }
916 }
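// Illustrative example: MatMul of shapes (2, 1, 3, 4) and (5, 4, 6)
// broadcasts the batch dims (2, 1) and (5) to (2, 5) and appends the matrix
// dims, giving (2, 5, 3, 6). Rank-1 inputs follow numpy.matmul, so
// (4) x (4, 6) gives (6).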
917
void ProcessReduceNode(Node* n) {
919 if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
920 auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
921 auto input_shape_value_0 = input_shape_0.value().sizes();
922 size_t rank_0 = input_shape_value_0.value().size();
923 std::vector<::c10::ShapeSymbol> final_shape;
924 std::vector<int64_t> axes_vector(rank_0);
925 if (n->hasAttributeS("axes")) {
926 axes_vector = n->is(attr::axes);
927 } else if (n->inputs().size() > 1) {
928 axes_vector =
929 ConstantValueMap::GetValueInto1DInt64Vector(n->input(1)->debugName());
930 } else {
931 std::iota(axes_vector.begin(), axes_vector.end(), 0);
932 }
933
934 for (auto idx : c10::irange(axes_vector.size())) {
935 if (axes_vector[idx] < 0) {
936 axes_vector[idx] += rank_0;
937 }
938 }
939 final_shape.reserve(rank_0);
940 // ONNX keepdims defaults to 1 when not set.
941 int64_t keepdims = 1;
942 if (n->hasAttributeS("keepdims")) {
943 keepdims = n->i(attr::keepdims);
944 }
945 for (auto idx : c10::irange(rank_0)) {
946 auto it = std::find(axes_vector.begin(), axes_vector.end(), idx);
947 if (it != axes_vector.end()) {
948 if (keepdims != 0) {
949 final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
950 }
951 } else {
952 final_shape.emplace_back(input_shape_value_0.value()[idx]);
953 }
954 }
955 UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
956 }
957 }
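// Illustrative example: ReduceMean over axes = [1] on an input of shape
// (2, 3, 4) yields (2, 4) with keepdims = 0, and (2, 1, 4) with keepdims = 1.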
958
void ProcessReshapeNode(Node* n, int opset_version) {
960 const auto& input_name = n->input(0)->debugName();
961 const auto& shape_name = n->input(1)->debugName();
962
963 // When `shape` input value is statically known, compute output shape.
964 if (ConstantValueMap::HasValue(shape_name)) {
965 auto static_shape_value =
966 ConstantValueMap::GetValueInto1DInt64Vector(shape_name);
967 auto symbolic_input_shape = ConstantValueMap::GetShape(input_name);
968 if (symbolic_input_shape && !static_shape_value.empty()) {
969 auto final_shape = ComputeShapeFromReshape(
970 n,
971 symbolic_input_shape.value(),
972 c10::SymbolicShape(static_shape_value),
973 opset_version);
974 if (final_shape) {
975 UpdateShape(n->output(), final_shape.value());
976 return;
977 }
978 }
979 }
980
981 // When `shape` input value is symbolically known, compute output shape.
982 if (ConstantValueMap::HasShapeValue(shape_name) &&
983 ConstantValueMap::HasShape(input_name)) {
984 auto symbolic_input_shape = ConstantValueMap::GetShape(input_name).value();
985 auto symbolic_shape_value =
986 ConstantValueMap::GetShapeValue(shape_name).value();
987 auto final_shape = ComputeShapeFromReshape(
988 n, symbolic_input_shape, symbolic_shape_value, opset_version);
989 if (final_shape.has_value()) {
990 UpdateShape(n->output(), final_shape.value());
991 return;
992 }
993 }
994
  // Only the shape of the new shape (the `shape` input) is known; assign the
  // output rank.
996 if (ConstantValueMap::HasShape(shape_name)) {
997 auto output_rank = ConstantValueMap::GetShapeInto1DInt64Vector(shape_name);
998 if (output_rank.has_value()) {
999 TORCH_INTERNAL_ASSERT(output_rank.value().size() == 1);
1000 UpdateRank(n->output(), output_rank.value()[0]);
1001 return;
1002 }
1003 }
1004
  // ListConstruct is handled at the beginning of ProcessConstantValueMap; no
  // further processing is needed here.
1007 if (TensorTypePtr shape_type = n->input(1)->type()->cast<TensorType>()) {
1008 // Set rank to Reshape output if possible.
1009 // From shape inference, we have:
1010 // %4236 : Float(*, device=cpu) = onnx::Transpose[perm=[0]](%4235)
1011 // %4237 : Long(2, strides=[1], device=cpu) = onnx::Concat[axis=0](%4232)
1012 // %4238 : FloatTensor(device=cpu) = onnx::Reshape(%4236, %4237)
1013 // We can have it as SymbolicShape with known rank:
1014 // %4238 : Float(*, *, strides=[2480, 1], requires_grad=0, device=cpu) =
1015 // onnx::Reshape(%4236, %4237)
1016 auto shape_type_dim = shape_type->dim();
1017 if (shape_type_dim.has_value()) {
1018 auto shape_type_size = shape_type->sizes()[0];
1019 if (shape_type_size.has_value()) {
1020 size_t rank = shape_type_size.value();
1021 UpdateRank(n->output(), rank);
1022 }
1023 }
1024 }
1025 }
1026
c10::SymbolicShape ComputeShapeForSlice(
1028 const std::vector<c10::ShapeSymbol>& input_shape,
1029 const std::vector<int64_t>& start_vector,
1030 const std::vector<int64_t>& end_vector,
1031 const std::vector<int64_t>& axes_vector,
1032 const std::vector<int64_t>& step_vector) {
1033 TORCH_INTERNAL_ASSERT(axes_vector.size() <= input_shape.size());
1034 TORCH_INTERNAL_ASSERT(axes_vector.size() == start_vector.size());
1035 TORCH_INTERNAL_ASSERT(axes_vector.size() == end_vector.size());
1036 TORCH_INTERNAL_ASSERT(axes_vector.size() == step_vector.size());
1037 std::vector<c10::ShapeSymbol> final_shape;
1038 final_shape = input_shape;
1039 for (const auto idx : c10::irange(axes_vector.size())) {
1040 auto axis = axes_vector[idx];
1041 TORCH_INTERNAL_ASSERT(axis >= 0);
1042 if (!input_shape[axis].is_static()) {
1043 final_shape[axis] = c10::ShapeSymbol::newSymbol();
1044 continue;
1045 }
1046 auto input_shape_axis_value = input_shape[axis].static_size();
1047 auto cur_start = start_vector[idx];
1048 auto cur_end = end_vector[idx];
1049 auto cur_step = step_vector[idx];
1050 if (cur_start < -input_shape_axis_value) {
1051 cur_start = 0;
1052 } else if (cur_start < 0) {
1053 cur_start = input_shape_axis_value + cur_start;
1054 } else if (cur_start > input_shape_axis_value - 1) {
1055 cur_start = input_shape_axis_value;
1056 }
1057 if (cur_end < -input_shape_axis_value) {
1058 cur_end = -1;
1059 } else if (cur_end < 0) {
1060 cur_end = input_shape_axis_value + cur_end;
1061 } else if (cur_end > input_shape_axis_value - 1) {
1062 cur_end = input_shape_axis_value;
1063 }
1064 TORCH_INTERNAL_ASSERT(cur_step != 0);
1065 if (cur_step > 0) {
1066 final_shape[axis] = c10::ShapeSymbol::fromStaticSize(
1067 (cur_end - cur_start - 1) / cur_step + 1);
1068 } else {
1069 final_shape[axis] = c10::ShapeSymbol::fromStaticSize(
1070 (cur_start - cur_end - 1) / (-cur_step) + 1);
1071 }
1072 }
1073 return c10::SymbolicShape(final_shape);
1074 }
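// Illustrative example: slicing a static dim of size 10 with start = 2,
// end = 8, step = 2 keeps indices {2, 4, 6}, so the dim becomes
// (8 - 2 - 1) / 2 + 1 = 3.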
1075
void ProcessSliceNode(Node* n, int opset_version) {
1077 bool valid = ConstantValueMap::HasShape(n->input(0)->debugName());
1078
1079 // For opset version <= 9, starts, ends, axes, steps are attributes,
1080 // so their values are always valid.
1081 if (opset_version >= 10) {
1082 // We can only infer shapes if 'axes' is known.
1083 if (n->inputs().size() > 3) {
1084 valid = valid && ConstantValueMap::HasValue(n->input(3)->debugName());
1085 }
1086 }
1087
1088 if (!valid) {
1089 if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
1090 auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
1091 UpdateRank(n->output(), rank);
1092 }
1093 return;
1094 } else {
1095 auto shape_size_0 =
1096 ConstantValueMap::GetShape(n->input(0)->debugName()).value();
1097 if (shape_size_0.rank().has_value()) {
1098 auto input0_shape_value = shape_size_0.sizes().value();
1099
1100 std::vector<int64_t> start_vector;
1101 std::vector<int64_t> end_vector;
1102 std::vector<int64_t> step_vector;
1103
1104 std::vector<int64_t> axes_vector(input0_shape_value.size(), 0);
1105 for (const auto i : c10::irange(input0_shape_value.size())) {
1106 axes_vector[i] = i;
1107 }
1108 if (opset_version >= 10 && n->inputs().size() > 3) {
1109 axes_vector = ConstantValueMap::GetValueInto1DInt64Vector(
1110 n->input(3)->debugName());
1111 } else if (opset_version < 10 && n->hasAttributeS("axes")) {
1112 axes_vector = n->is(attr::axes);
1113 }
1114 for (auto& axis : axes_vector) {
1115 if (axis < 0) {
1116 axis += input0_shape_value.size();
1117 }
1118 }
1119
1120 if (opset_version < 10) {
1121 start_vector = n->is(attr::starts);
1122 end_vector = n->is(attr::ends);
1123 } else {
1124 // If starts, ends, or step are unknown,
1125 // then mark all dimensions in 'axes' as unknown.
1126 std::vector<uint64_t> indices = {1U, 2U, 4U};
1127 bool start_end_step_known =
1128 std::all_of(indices.begin(), indices.end(), [&n](auto i) {
1129 return (i >= n->inputs().size()) ||
1130 ConstantValueMap::HasValue(n->input(i)->debugName());
1131 });
1132 if (!start_end_step_known) {
1133 auto final_shape = input0_shape_value;
1134 for (const auto axis : axes_vector) {
1135 final_shape[axis] = c10::ShapeSymbol::newSymbol();
1136 }
1137 UpdateShape(n->output(), final_shape);
1138 return;
1139 }
1140
1141 start_vector = ConstantValueMap::GetValueInto1DInt64Vector(
1142 n->input(1)->debugName());
1143 end_vector = ConstantValueMap::GetValueInto1DInt64Vector(
1144 n->input(2)->debugName());
1145 if (n->inputs().size() > 4) {
1146 step_vector = ConstantValueMap::GetValueInto1DInt64Vector(
1147 n->input(4)->debugName());
1148 }
1149 }
1150
1151 if (step_vector.empty()) {
1152 step_vector = std::vector<int64_t>(axes_vector.size(), 1);
1153 }
1154
1155 auto final_shape = ComputeShapeForSlice(
1156 input0_shape_value,
1157 start_vector,
1158 end_vector,
1159 axes_vector,
1160 step_vector);
1161 UpdateShape(n->output(), final_shape);
1162 }
1163 }
1164 }
1165
void ProcessUnchangeNode(Node* n) {
1167 if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1168 auto shape_size_0 =
1169 ConstantValueMap::GetShape(n->input(0)->debugName()).value();
1170 UpdateShape(n->output(), shape_size_0);
1171 }
1172 }
1173
void ProcessTimeSeriesNode(Node* n) {
1175 auto input0_shape = ConstantValueMap::GetShape(n->input(0)->debugName());
1176 auto input1_shape = ConstantValueMap::GetShape(n->input(1)->debugName());
1177 if (!(input0_shape.has_value() && input1_shape.has_value())) {
1178 return;
1179 }
1180 auto input0_shape_value = input0_shape.value().sizes();
1181 auto input1_shape_value = input1_shape.value().sizes();
1182 c10::ShapeSymbol seq_length;
1183 c10::ShapeSymbol num_directions;
1184 c10::ShapeSymbol batch_size;
1185 c10::ShapeSymbol hidden_size;
1186 if (input0_shape_value.has_value()) {
1187 seq_length = input0_shape_value.value()[0];
1188 batch_size = input0_shape_value.value()[1];
1189 }
1190
1191 if (input1_shape_value.has_value()) {
1192 num_directions = input1_shape_value.value()[0];
1193 if (input1_shape_value.value()[1].is_static()) {
1194 auto input1_value = input1_shape_value.value()[1].static_size();
1195 switch (n->kind()) {
1196 case ::c10::onnx::RNN:
1197 hidden_size = c10::ShapeSymbol::fromStaticSize(input1_value);
1198 break;
1199 case ::c10::onnx::LSTM:
1200 hidden_size = c10::ShapeSymbol::fromStaticSize(input1_value / 4);
1201 break;
1202 case ::c10::onnx::GRU:
1203 hidden_size = c10::ShapeSymbol::fromStaticSize(input1_value / 3);
1204 break;
1205 default:
1206 throw std::runtime_error(
1207 std::string() + "This is not a valid TimeSeries Node with type " +
1208 n->kind().toDisplayString());
1209 }
1210 } else {
1211 hidden_size = c10::ShapeSymbol::newSymbol();
1212 }
1213 }
1214
1215 if (n->outputs().size() > 1) {
1216 std::vector<c10::ShapeSymbol> final_shape = {
1217 seq_length, num_directions, batch_size, hidden_size};
1218 UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
1219 }
1220 for (const auto idx : c10::irange(2U, 4U)) {
1221 if (n->outputs().size() > idx) {
1222 std::vector<c10::ShapeSymbol> final_shape = {
1223 num_directions, batch_size, hidden_size};
1224 UpdateShape(n->output(idx - 1), c10::SymbolicShape(final_shape));
1225 }
1226 }
1227 }
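// Illustrative example: for onnx::LSTM the weight input W has shape
// (num_directions, 4 * hidden_size, input_size), so hidden_size is recovered
// as W.shape[1] / 4; output 0 then gets shape
// (seq_length, num_directions, batch_size, hidden_size).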
1228
void ProcessUnsqueezeNode(Node* n) {
1230 TensorTypePtr output_type = n->output(0)->type()->cast<TensorType>();
1231 if (output_type == nullptr) {
1232 return;
1233 }
1234 if (output_type->dim().has_value() && output_type->dim().value() == 1 &&
1235 ConstantValueMap::HasShapeValue(n->input(0)->debugName())) {
1236 auto shape_value =
1237 ConstantValueMap::GetShapeValue(n->input(0)->debugName()).value();
1238 // When the scalar represents a shape, it is the same as the shape value
1239 // when it gets unsqueezed.
1240 ConstantValueMap::SetShapeValue(n->output()->debugName(), shape_value);
1241 }
1242 }
1243
// As an addition to ONNX shape inference, this function leverages constant
// folding and a per-op process to update the rank/shape of graph values, and
// it also updates ConstantValueMap accordingly.
void ComputeConstant(Node* n, int opset_version) {
1248 if (n->kind() == ::c10::onnx::Constant) {
1249 if (n->kindOf(attr::value) == AttributeKind::t) {
1250 at::Tensor const_val = n->t(attr::value);
1251 at::Tensor const_val_copy =
1252 at::empty(const_val.sizes(), const_val.options());
1253 const_val_copy.copy_(const_val);
1254 ConstantValueMap::SetValue(n->output()->debugName(), const_val_copy);
1255 }
1256 return;
1257 }
1258 auto only_rank_available = false;
1259 size_t rank = 0;
1260
1261 // Constant folding.
1262 auto const_fold_val = ComputeConstantFolding(n, opset_version);
1263 if (const_fold_val.has_value()) {
1264 at::Tensor const_fold_val_copy = at::empty(
1265 const_fold_val.value().sizes(), const_fold_val.value().options());
1266 const_fold_val_copy.copy_(const_fold_val.value());
1267 ConstantValueMap::SetValue(n->output()->debugName(), const_fold_val_copy);
1268 UpdateShapeFromVector(n->output(), const_fold_val_copy.sizes().vec());
1269 return;
1270 }
1271
1272 switch (n->kind()) {
1273 case ::c10::onnx::Add:
1274 case ::c10::onnx::Div:
1275 case ::c10::onnx::Equal:
1276 case ::c10::onnx::Greater:
1277 case ::c10::onnx::GreaterOrEqual:
1278 case ::c10::onnx::Less:
1279 case ::c10::onnx::LessOrEqual:
1280 case ::c10::onnx::Mod:
1281 case ::c10::onnx::Mul:
1282 case ::c10::onnx::Pow:
1283 case ::c10::onnx::Sub: {
1284 ProcessBroadcastNode(n);
1285 break;
1286 }
1287 case ::c10::onnx::Shape: {
1288 auto input_shape =
1289 ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName());
1290 if (input_shape.has_value()) {
1291 auto shape_value = input_shape.value();
1292 // TODO: getDevice() ?
1293 auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
1294 auto shape_value_size = static_cast<int64_t>(shape_value.size());
1295 auto f =
1296 at::from_blob(shape_value.data(), {shape_value_size}, at::kLong)
1297 .to(at::kCPU);
1298 // Need copy here
1299 at::Tensor f_copy = at::empty({shape_value_size}, options);
1300 f_copy.copy_(f);
1301 ConstantValueMap::SetValue(n->output()->debugName(), f_copy);
1302 std::vector<::c10::ShapeSymbol> final_shape_vector(
1303 1, c10::ShapeSymbol::fromStaticSize(shape_value_size));
1304 ::c10::SymbolicShape final_shape(final_shape_vector);
1305 UpdateShape(n->output(), final_shape);
1306 }
1307 break;
1308 }
1309 case ::c10::onnx::Reshape: {
1310 ProcessReshapeNode(n, opset_version);
1311 break;
1312 }
1313 case ::c10::onnx::Transpose: {
1314 if (n->hasAttributeS("perm")) {
1315 auto perm_v = n->is(attr::perm);
1316 rank = perm_v.size();
1317 auto is_default_perm = false;
1318 if (rank == 2 && perm_v[0] == 1 && perm_v[1] == 0) {
1319 is_default_perm = true;
1320 }
1321 auto shape_updated = false;
1322 if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1323 auto shape_size_0 =
1324 ConstantValueMap::GetShape(n->input(0)->debugName())
1325 .value()
1326 .sizes();
1327 if (shape_size_0.has_value()) {
1328 auto shape_vector_0 = shape_size_0.value();
1329 std::vector<::c10::ShapeSymbol> final_shape_vector(
1330 shape_vector_0.size(), ::c10::ShapeSymbol());
1331 if (is_default_perm) {
1332 std::reverse_copy(
1333 std::begin(shape_vector_0),
1334 std::end(shape_vector_0),
1335 std::begin(final_shape_vector));
1336 } else {
1337 for (const auto i : c10::irange(shape_vector_0.size())) {
1338 final_shape_vector[i] = shape_vector_0[perm_v[i]];
1339 }
1340 }
1341 ::c10::SymbolicShape final_shape(final_shape_vector);
1342 UpdateShape(n->output(), final_shape);
1343 shape_updated = true;
1344 }
1345 }
1346 if (!shape_updated) {
1347 if (!is_default_perm) {
1348 only_rank_available = true;
1349 } else if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
1350 rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
1351 only_rank_available = true;
1352 }
1353 }
1354 }
1355 break;
1356 }
1357 case ::c10::onnx::Concat: {
1358 ProcessConcatNode(n);
1359 break;
1360 }
1361 case ::c10::onnx::ConstantOfShape: {
1362 if (ConstantValueMap::HasValue(n->input()->debugName())) {
1363 auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
1364 n->input()->debugName());
1365 UpdateShapeFromVector(n->output(), shape_temp);
1366 if (!shape_temp.empty()) {
1367 if (n->hasAttributeS("value")) {
1368 auto value = n->t(attr::value).repeat(shape_temp);
1369 ConstantValueMap::SetValue(n->output()->debugName(), value);
1370 } else {
1371 auto options =
1372 c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
1373 auto value = at::full({1}, 0.0, options).repeat(shape_temp);
1374 ConstantValueMap::SetValue(n->output()->debugName(), value);
1375 }
1376 }
1377 }
1378 break;
1379 }
1380 case ::c10::onnx::Expand: {
1381 if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1382 auto input0_shape_size =
1383 ConstantValueMap::GetShape(n->input(0)->debugName())
1384 .value()
1385 .sizes();
1386 if (input0_shape_size.has_value()) {
1387 auto input0_shape_value = input0_shape_size.value();
1388 if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
1389 // When value of `shape` is statically known,
1390 // output shape can be computed.
1391 auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
1392 n->input(1)->debugName());
1393 auto final_shape =
1394 ComputeShapeFromExpand(input0_shape_value, shape_temp);
1395 if (final_shape.has_value()) {
1396 UpdateShape(n->output(), final_shape.value());
1397 }
1398 } else if (
1399 auto expand_shape =
1400 ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
1401 n->input(1)->debugName())) {
1402 // When shape of `shape` is statically known,
1403 // output rank can be computed.
1404 TORCH_INTERNAL_ASSERT(
1405 expand_shape.value().size() == 1,
1406 "`Shape` input to `Expand` should be a 1-D tensor. Instead got rank ",
1407 expand_shape.value().size());
1408 if (expand_shape.value()[0] > 0) {
1409 std::vector<c10::ShapeSymbol> final_shape;
1410 std::generate_n(
1411 std::back_inserter(final_shape),
1412 expand_shape.value()[0],
1413 ::c10::ShapeSymbol::newSymbol);
1414 UpdateShape(n->output(), c10::SymbolicShape(final_shape));
1415 }
1416 }
1417 }
1418 }
1419 break;
1420 }
1421 case ::c10::onnx::NonZero: {
1422 if (ConstantValueMap::HasRank(n->input()->debugName())) {
1423 auto rank = ConstantValueMap::GetRank(n->input()->debugName()).value();
1424 std::vector<c10::ShapeSymbol> dims;
1425 dims.emplace_back(
1426 c10::ShapeSymbol::fromStaticSize(static_cast<int64_t>(rank)));
1427 auto input_node = n->input()->node();
1428 if (input_node->kind() == ::c10::onnx::ConstantOfShape) {
1429 if (input_node->hasAttributeS("value")) {
1430 auto value =
1431 input_node->t(attr::value).toType(at::ScalarType::Float);
1432 auto value_a = value.accessor<float, 1>();
1433 if (value_a.size(0) == 1 && std::abs(value_a[0]) > 1e-6) {
1434 if (ConstantValueMap::HasShape(n->input()->debugName())) {
1435 auto shape_size_0 =
1436 ConstantValueMap::GetShape(n->input()->debugName()).value();
1437 if (shape_size_0.isComplete()) {
1438 auto shape_vector_0 = shape_size_0.sizes().value();
1439 int64_t num_elements = 1;
1440 for (auto cur_dim : shape_vector_0) {
1441 num_elements *= cur_dim.static_size();
1442 }
1443 dims.emplace_back(c10::ShapeSymbol::fromStaticSize(
1444 static_cast<int64_t>(num_elements)));
1445 }
1446 }
1447 }
1448 }
1449 }
1450 if (dims.size() == 1) {
1451 dims.emplace_back(c10::ShapeSymbol::newSymbol());
1452 }
1453 c10::SymbolicShape shape_v(dims);
1454 UpdateShape(n->output(), shape_v);
1455 }
1456 break;
1457 }
1458 case ::c10::onnx::MatMul: {
1459 ProcessMatMulNode(n);
1460 break;
1461 }
1462 case ::c10::onnx::ReduceMean:
1463 case ::c10::onnx::ReduceProd: {
1464 ProcessReduceNode(n);
1465 break;
1466 }
1467 case ::c10::onnx::RNN:
1468 case ::c10::onnx::LSTM:
1469 case ::c10::onnx::GRU: {
1470 ProcessTimeSeriesNode(n);
1471 break;
1472 }
1473 case ::c10::onnx::Size: {
1474 if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1475 auto input0_shape_size =
1476 ConstantValueMap::GetShape(n->input(0)->debugName())
1477 .value()
1478 .sizes();
1479 if (input0_shape_size.has_value()) {
1480 auto input0_shape_value = input0_shape_size.value();
1481 int64_t total_size = 1;
1482 auto is_full_static = true;
1483 for (const auto i : c10::irange(input0_shape_value.size())) {
1484 if (input0_shape_value[i].is_static()) {
1485 total_size *= input0_shape_value[i].static_size();
1486 } else {
1487 is_full_static = false;
1488 break;
1489 }
1490 }
1491 if (is_full_static) {
1492 auto f_final = onnx_constant_fold::IntToTensor(total_size);
1493 ConstantValueMap::SetValue(n->output(0)->debugName(), f_final);
1494 }
1495 }
1496 }
1497 break;
1498 }
1499 case ::c10::onnx::Slice: {
1500 ProcessSliceNode(n, opset_version);
1501 break;
1502 }
1503 case ::c10::onnx::Cast:
1504 case ::c10::onnx::Relu:
1505 case ::c10::onnx::Softmax: {
1506 ProcessUnchangeNode(n);
1507 break;
1508 }
1509 case ::c10::onnx::Tile: {
1510 if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
1511 auto input0_shape_size =
1512 ConstantValueMap::GetShape(n->input(0)->debugName())
1513 .value()
1514 .sizes();
1515 if (input0_shape_size.has_value()) {
1516 auto input0_shape_value = input0_shape_size.value();
1517 if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
1518 auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
1519 n->input(1)->debugName());
1520 auto final_shape =
1521 ComputeShapeFromTile(input0_shape_value, shape_temp);
1522 if (final_shape.has_value()) {
1523 UpdateShape(n->output(), final_shape.value());
1524 }
1525 }
1526 }
1527 }
1528 break;
1529 }
1530 case ::c10::onnx::Unsqueeze: {
1531 ProcessUnsqueezeNode(n);
1532 break;
1533 }
1534 default: {
1535 break;
1536 }
1537 }
1538 if (n->outputs().size() > 1 ||
1539 ConstantValueMap::HasShape(n->output(0)->debugName())) {
1540 return;
1541 }
1542 if (only_rank_available) {
1543 UpdateRank(n->output(), rank);
1544 }
1545 }
1546
bool IsListConstructIntType(const Value* v) {
1548 if (v->node()->kind() == prim::ListConstruct) {
1549 auto listType = v->node()->output()->type();
1550 auto containedType = listType->containedTypes().at(0);
1551 if (containedType == IntType::get()) {
1552 return true;
1553 }
1554 }
1555 return false;
1556 }
1557
1558 // Check whether all graph inputs are static, returning a cached result when
1559 // available. Since this traverses all inputs of the graph (including weights),
1560 // it can be costly for large graphs. Because it is called for every node during
1561 // export and the inputs remain unchanged, caching cuts down export time.
1562 bool AllGraphInputsStaticWithCaching(const Graph* g) {
1563 auto maybe_is_static = ConstantValueMap::GetAllGraphInputsStatic();
1564 if (maybe_is_static.has_value()) {
1565 return maybe_is_static.value();
1566 } else {
1567 bool ret = AllGraphInputsStatic(g);
1568 ConstantValueMap::SetAllGraphInputsStatic(ret);
1569 return ret;
1570 }
1571 }
1572
1573 void ProcessConstantValueMap(Node* n, int opset_version) {
1574 // Update ConstantValueMap on node outputs from ONNX shape inference.
1575 // For outputs, only static shapes are updated; for inputs, symbolic shapes
1576 // are updated as well. ONNX If can have different types on different
1577 // branches, so it is skipped here.
1578
1579 // Update the shape reliability for each node before processing
1580 // ConstantValueMap to prevent unreliable nodes from producing static
1581 // shapes
1582 UpdateReliable(n);
1583
1584 auto static_input_shape = AllGraphInputsStaticWithCaching(n->owningGraph());
1585 for (auto i : c10::irange(n->outputs().size())) {
1586 if (TensorTypePtr output_type = n->output(i)->type()->cast<TensorType>()) {
1587 if (output_type->dim().has_value()) {
1588 size_t rank = static_cast<size_t>(output_type->dim().value());
1589 ConstantValueMap::SetRank(n->output(i)->debugName(), rank);
1590 auto shape = output_type->symbolic_sizes();
1591 if (shape.isComplete()) {
1592 UpdateShape(n->output(i), shape);
1593 }
1594 }
1595 }
1596 }
1597 // Update ConstantValueMap on node inputs from ONNX shape inference.
1598 // ListConstruct is handled here (only IntType is considered, not TensorType),
1599 // so no per-op processing is needed.
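// Illustrative example (hypothetical value names): an input
//   %sizes : int[] = prim::ListConstruct(%2, %3)
// whose elements fold to the constant ints 2 and 3 is recorded in
// ConstantValueMap as the 1-D int64 tensor [2, 3] with shape {2}.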
1600 for (auto i : c10::irange(n->inputs().size())) {
1601 if (TensorTypePtr input_type = n->input(i)->type()->cast<TensorType>()) {
1602 if (input_type->dim().has_value()) {
1603 size_t rank = static_cast<size_t>(input_type->dim().value());
1604 ConstantValueMap::SetRank(n->input(i)->debugName(), rank);
1605 // Only update the shape if the input comes from an ONNX node.
1606 // If it comes from aten operators, for example,
1607 // Float(20, 20, strides=[1, 0], requires_grad=0, device=cpu),
1608 // %399 : Float(20, 20, strides=[0, 1], requires_grad=0, device=cpu)
1609 // = prim::ListUnpack(%397)
1610 // the traced shape may not be correct when dynamic_axes is enabled.
1611 if (n->input(i)->node()->kind().is_onnx() || static_input_shape) {
1612 auto shape = input_type->symbolic_sizes();
1613 if (!ConstantValueMap::HasShape(n->input(i)->debugName())) {
1614 UpdateShape(n->input(i), shape);
1615 }
1616 }
1617 }
1618 } else if (IsListConstructIntType(n->input(i))) {
1619 auto lc_node = n->input(i)->node();
1620 auto rank = lc_node->inputs().size();
1621 auto lc_vector_optional = GetValueFromListConstructNode(lc_node);
1622 if (lc_vector_optional.has_value()) {
1623 auto lc_vector = lc_vector_optional.value();
1624 auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
1625 auto lc_vector_size = static_cast<int64_t>(lc_vector.size());
1626 auto f = at::from_blob(lc_vector.data(), {lc_vector_size}, at::kLong)
1627 .to(at::kCPU);
1628 // from_blob does not own lc_vector's data, so make an owning copy here.
1629 at::Tensor f_copy = at::empty({lc_vector_size}, options);
1630 f_copy.copy_(f);
1631 ConstantValueMap::SetValue(n->input(i)->debugName(), f_copy);
1632 UpdateShapeFromVector(n->input(i), {lc_vector_size});
1633 } else {
1634 UpdateShapeFromVector(n->input(i), {static_cast<int64_t>(rank)});
1635 }
1636 SetShapeValueFromListConstructNode(lc_node);
1637 }
1638 }
1639 // Additional logic to update the graph and ConstantValueMap
1640 ComputeConstant(n, opset_version);
1641 }
1642
1643 // Additional post-processing specific to individual node kinds.
1644 void SpecialPostProcess(Node* n) {
1645 switch (n->kind()) {
1646 case ::c10::onnx::SequenceInsert: {
1647 // Special case when input sequence to SequenceInsert is empty.
1648 // onnx Sequence type requires element type to be set.
1649 // If the list to insert is empty, we set the elem type by
1650 // looking at the tensor being inserted.
1651 auto seq_node = n->input(0)->node();
1652 auto t_type = n->input(1)->type()->cast<TensorType>();
1653
1654 auto update_sequence_empty_dtype = [](Node* n,
1655 const TensorTypePtr& t_type) {
1656 TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty);
1657 TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value());
1658 auto scalar_type = t_type->scalarType().value();
1659 auto onnx_type = ATenTypeToOnnxType(scalar_type);
1660 n->i_(attr::dtype, onnx_type);
1661 n->output()->setType(ListType::create(t_type));
1662 };
1663
1664 auto find_sequence_empty = [](Value* input,
1665 TensorTypePtr t_type) -> Node* {
1666 auto find_sequence_empty_impl =
1667 [](Value* input,
1668 TensorTypePtr t_type,
1669 auto& find_sequence_empty_ref) -> Node* {
1670 auto input_node = input->node();
1671 TORCH_INTERNAL_ASSERT(input_node);
1672
1673 // 1. Input is from SequenceEmpty.
1674 if (input_node->kind() == ::c10::onnx::SequenceEmpty) {
1675 return input_node;
1676 }
1677
1678 // 2. Input is subblock input of a Loop node, which takes outer block
1679 // SequenceEmpty as input.
1680 if (input_node->kind() == prim::Param) {
1681 auto loop_n = input_node->owningBlock()->owningNode();
1682 if (nullptr == loop_n || loop_n->kind() != ::c10::onnx::Loop) {
1683 return nullptr;
1684 }
1685
1686 auto it = std::find(
1687 input_node->outputs().begin(),
1688 input_node->outputs().end(),
1689 input);
1690 auto idx = std::distance(input_node->outputs().begin(), it);
1691
1692 auto outer_block_node = loop_n->input(idx)->node();
1693 if (outer_block_node &&
1694 outer_block_node->kind() == ::c10::onnx::SequenceEmpty) {
1695 // Found SequenceEmpty
1696 input->setType(ListType::create(t_type));
1697 return outer_block_node;
1698 } else {
1699 // The outer block node is still not SequenceEmpty; recurse to
1700 // handle nested loops.
1701 auto found_n = find_sequence_empty_ref(
1702 loop_n->input(idx), t_type, find_sequence_empty_ref);
1703 if (found_n) {
1704 input->setType(ListType::create(t_type));
1705 }
1706 return found_n;
1707 }
1708 }
1709
1710 // Could not find source SequenceEmpty node.
1711 return nullptr;
1712 };
1713 return find_sequence_empty_impl(
1714 input, std::move(t_type), find_sequence_empty_impl);
1715 };
1716
1717 if (seq_node && t_type && t_type->scalarType()) {
1718 if (seq_node->kind() == ::c10::onnx::SequenceEmpty) {
1719 update_sequence_empty_dtype(seq_node, t_type);
1720 } else if (seq_node->kind() == prim::Param) {
1721 // Try to find original onnx::SequenceEmpty node in outer block.
1722 auto seq_empty_n = find_sequence_empty(n->input(0), t_type);
1723 if (seq_empty_n) {
1724 update_sequence_empty_dtype(seq_empty_n, t_type);
1725 }
1726 }
1727 n->output()->setType(ListType::create(t_type));
1728 }
1729
1730 break;
1731 }
1732 case ::c10::onnx::Cast: {
1733 // ONNX shape inference is not able to assign the output tensor type
1734 // when the input to onnx::Cast has incomplete tensor information, for
1735 // example a missing shape, rank, or dtype. This post-process sets the
1736 // correct dtype for the output tensor, since the dtype info is stored
1737 // in the Cast attribute.
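// Illustrative example: a Cast node with attribute to=FLOAT produces a
// Float-typed output here even if the input tensor type carries no dtype.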
1738 TensorTypePtr t_type = n->output()->type()->cast<TensorType>();
1739 if (nullptr != t_type && !t_type->scalarType().has_value()) {
1740 auto onnx_dtype = n->i(attr::to);
1741 auto aten_dtype = ONNXTypeToATenType(onnx_dtype);
1742 n->output()->setType(t_type->withScalarType(aten_dtype));
1743 }
1744 break;
1745 }
1746 case ::c10::onnx::ConstantOfShape: {
1747 // ONNX shape inference is not able to propagate output tensor shape
1748 // for onnx::ConstantOfShape if input `shape` is not constant.
1749 // This is a temporary solution when some partial information is
1750 // available, for example, knowing rank of output tensor, or knowing
1751 // symbolic shape. This solution won't be needed once we have proper
1752 // symbolic propagation.
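// Illustrative example (hypothetical value names): for the pattern
//   %s = onnx::Shape(%x), %y = onnx::ConstantOfShape(%s)
// the output %y is assigned the symbolic shape of %x below.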
1753 auto shape_node = n->input(0)->node();
1754 if (shape_node->kind() == ::c10::onnx::Shape) {
1755 // Shape -> ConstantOfShape
1756 auto orig_type = shape_node->input()->type()->cast<TensorType>();
1757 auto v_type = n->output()->type()->cast<TensorType>();
1758 if (v_type && !v_type->sizes().concrete_sizes()) {
1759 if (orig_type && orig_type->dim()) {
1760 // Assign symbolic shape of original input of onnx::Shape.
1761 v_type = v_type->withSymbolicShapes(orig_type->symbolic_sizes());
1762 n->output()->setType(v_type);
1763 } else if (
1764 shape_node->input()->node()->kind() ==
1765 ::c10::prim::ListConstruct) {
1766 // Assign rank of original input of onnx::Shape.
1767 v_type = v_type->withSizes({static_cast<int64_t>(
1768 shape_node->input()->node()->inputs().size())});
1769 n->output()->setType(v_type);
1770 }
1771 }
1772 } else if (shape_node->kind() == ::c10::prim::ListConstruct) {
1773 // ListConstruct -> ConstantOfShape
1774 auto v_type = n->output()->type()->cast<TensorType>();
1775 if (v_type && !v_type->sizes().concrete_sizes()) {
1776 auto value = n->t(attr::value);
1777 v_type = v_type->withScalarType(value.scalar_type());
1778 std::vector<c10::ShapeSymbol> sizes(
1779 shape_node->inputs().size(), c10::ShapeSymbol::newSymbol());
1780 v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));
1781 n->output()->setType(v_type);
1782 }
1783 }
1784 break;
1785 }
1786 case ::c10::onnx::If: {
1787 if (!IsValidONNXControlflowNode(n)) {
1788 break;
1789 }
1790 FixupONNXControlflowNodeOutputs(n);
1791 break;
1792 }
1793 case ::c10::onnx::Loop: {
1794 if (!IsValidONNXControlflowNode(n)) {
1795 break;
1796 }
1797 FixupONNXControlflowNodeOutputs(n);
1798 break;
1799 }
1800 }
1801 }
1802
1803 void UpdateOutputTypeByONNXProto(
1804 Node* n,
1805 Node* clone_node,
1806 const onnx::ModelProto& model_proto,
1807 SymbolDimMap& symbol_dim_map,
1808 DimSymbolMap& dim_symbol_map) {
1809 const auto& graph_proto = model_proto.graph();
1810
1811 // Get data from value_info and update the original graph.
1812 const auto updateNodeOutputsByONNXValueInfo =
1813 [&](const onnx::ValueInfoProto& v_info) {
1814 for (size_t i = 0; i < n->outputs().size(); ++i) {
1815 if (clone_node->output(i)->debugName() == v_info.name()) {
1816 UpdateTorchValueByOnnxValueInfo(
1817 n->output(i), v_info, symbol_dim_map, dim_symbol_map);
1818 }
1819 }
1820 };
1821
1822 // Check graph outputs for inferred shapes.
1823 for (const auto i : c10::irange(graph_proto.output_size())) {
1824 updateNodeOutputsByONNXValueInfo(graph_proto.output(i));
1825 }
1826
1827 // Check value_infos for inferred shapes.
1828 for (const auto i : c10::irange(graph_proto.value_info_size())) {
1829 updateNodeOutputsByONNXValueInfo(graph_proto.value_info(i));
1830 }
1831 }
1832
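// For subgraphs owned by an onnx::Loop node, copy the owning node's input
// types onto the block inputs so shape inference inside the subgraph can
// see them.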
1833 void FetchBlockInputMetadataFromParent(Block* b) {
1834 auto n = b->owningNode();
1835 if (nullptr != n && n->kind() == ::c10::onnx::Loop) {
1836 // Copy node input metadata to subgraph input.
1837 for (size_t i = 0; i < n->inputs().size(); ++i) {
1838 b->inputs().at(i)->setType(n->inputs().at(i)->type());
1839 }
1840 }
1841 }
1842
1843 void RemoveProcessedInputs(const Node* n) {
1844 // After processing a node for shape inference, remove intermediate tensors
1845 // that are stored in ConstantValueMap to reduce memory usage.
1846 // This will only remove tensors that are no longer needed by any other node.
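// Illustrative example (hypothetical value name): once every consumer of %w
// has at least one output with a known rank, the cached tensor for %w is
// erased from ConstantValueMap.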
1847
1848 // Returns whether a node was already processed for shape inference.
1849 const auto isNodeProcessed = [](const Node* node) {
1850 const auto& outputs = node->outputs();
1851 return std::any_of(outputs.begin(), outputs.end(), [](const Value* output) {
1852 // Assumes shape inference can at least determine the rank of the outputs.
1853 // If this assumption is wrong, some intermediate tensors will only be
1854 // deleted once shape inference is completed for the entire graph.
1855 return ConstantValueMap::HasRank(output->debugName());
1856 });
1857 };
1858
1859 // An input value is no longer needed if all of its consumer nodes
1860 // have already been processed.
1861 const auto isValueNoLongerNeeded = [isNodeProcessed](const Value* input) {
1862 const auto& uses = input->uses();
1863 return std::all_of(
1864 uses.begin(), uses.end(), [isNodeProcessed](const Use& use) {
1865 return isNodeProcessed(use.user);
1866 });
1867 };
1868
1869 for (const auto* input : n->inputs()) {
1870 if (ConstantValueMap::HasValue(input->debugName()) &&
1871 isValueNoLongerNeeded(input)) {
1872 ConstantValueMap::EraseValue(input->debugName());
1873 }
1874 }
1875 }
1876
1877 void ONNXShapeTypeInference(
1878 Block* b,
1879 const ParamMap& params_dict,
1880 int opset_version) {
1881 FetchBlockInputMetadataFromParent(b);
1882 auto valsToParamsMap = buildValueToParamsMap(b, params_dict);
1883 for (auto const& it : valsToParamsMap) {
1884 auto key = it.first;
1885 auto value = it.second;
1886 if (key->node()->kind() == prim::Param) {
1887 if (value.second.isTensor()) {
1888 ConstantValueMap::SetValue(value.first, value.second.toTensor());
1889 }
1890 } else if (key->node()->kind() == ::c10::onnx::Constant) {
1891 at::Tensor const_val = key->node()->t(attr::value);
1892 at::Tensor const_val_copy =
1893 at::empty(const_val.sizes(), const_val.options());
1894 const_val_copy.copy_(const_val);
1895 ConstantValueMap::SetValue(value.first, const_val_copy);
1896 } else {
1897 throw std::runtime_error(
1898 "ONNXShapeTypeInference - Unsupported kind of constant node found.");
1899 }
1900 }
1901 for (auto n : b->nodes()) {
1902 for (auto subblock : n->blocks()) {
1903 ONNXShapeTypeInference(subblock, params_dict, opset_version);
1904 }
1905 ONNXShapeTypeInference(n, params_dict, opset_version);
1906 RemoveProcessedInputs(n);
1907 }
1908 }
1909
1910 } // namespace
1911
1912 // For some operators, certain inputs are not relevant to shape inference.
1913 // For example, LSTM input 4 (sequence_lens) is optional,
1914 // and shape inference can be done from the other required inputs.
1915 // When computing reliability, such inputs are not required to be reliable.
1916 static std::unordered_map<std::string, std::unordered_set<int64_t>>
1917 non_required_shape_inference_idx_map = {{"onnx::LSTM", {4}}};
1918
1919 bool AllGraphInputsStatic(const Graph* g) {
1920 for (auto n : g->inputs()) {
1921 if (TensorTypePtr input_type = n->type()->cast<TensorType>()) {
1922 if (input_type->dim()) {
1923 auto shape = input_type->symbolic_sizes();
1924 if (!ConstantValueMap::HasShape(n->debugName())) {
1925 UpdateShapeConstantValueMap(n, shape);
1926 }
1927 }
1928 }
1929 }
1930 for (auto n : g->inputs()) {
1931 // Some inputs can be non-Tensor type, e.g.,
1932 // __torch__.torch.classes.quantized.LinearPackedParamsBase
1933 // so we only need to check Tensor types here.
1934 if (n->type()->cast<TensorType>() && !n->isCompleteTensor()) {
1935 return false;
1936 }
1937 }
1938 return true;
1939 }
1940
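// Returns a pair (all inputs type-reliable, all tensor inputs have complete
// static shapes), ignoring None inputs and the inputs listed in
// non_required_shape_inference_idx_map for the node kind.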
1941 std::pair<bool, bool> AreInputsReliableOrStatic(Node* n) {
1942 auto reliable = true;
1943 auto complete = true;
1944 auto input_size = n->inputs().size();
1945 std::unordered_set<int64_t> non_required_idx = {};
1946 if (non_required_shape_inference_idx_map.find(n->kind().toDisplayString()) !=
1947 non_required_shape_inference_idx_map.end()) {
1948 non_required_idx =
1949 non_required_shape_inference_idx_map[n->kind().toDisplayString()];
1950 }
1951 for (auto idx : c10::irange(input_size)) {
1952 if (!non_required_idx.empty() &&
1953 non_required_idx.find(idx) != non_required_idx.end()) {
1954 continue;
1955 }
1956 auto input = n->inputs()[idx];
1957 // Always consider None reliable and complete, because it represents
1958 // unspecified optional inputs in ONNX.
1959 if (input->node()->mustBeNone()) {
1960 continue;
1961 }
1962 reliable &=
1963 ConstantValueMap::GetTypeReliable(input->debugName()).value_or(false);
1964 if (auto pt = input->type()->cast<TensorType>()) {
1965 if (!pt->sizes().isComplete()) {
1966 complete = false;
1967 }
1968 }
1969 }
1970 return std::make_pair(reliable, complete);
1971 }
1972
1973 // ONNX node types do not need to be listed here, but this is needed
1974 // for some legacy tests when onnx_shape_inference=False.
1975 static std::unordered_set<std::string> nodeTypeReliableForTracer = {
1976 "prim::ListConstruct",
1977 "onnx::Cast",
1978 "onnx::Constant",
1979 "onnx::Relu",
1980 "com.microsoft::Gelu",
1981 "aten::ATen"};
1982
1983 void UpdateReliable(
1984 torch::jit::Value* output,
1985 const std::pair<bool, bool>& inferred_type_reliable,
1986 bool no_type_warning) {
1987 auto inferred =
1988 ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false);
1989 auto isTypeReliableForTracer =
1990 nodeTypeReliableForTracer.find(
1991 output->node()->kind().toDisplayString()) !=
1992 nodeTypeReliableForTracer.end();
1993 if (!inferred && !isTypeReliableForTracer &&
1994 !output->node()->kind().is_onnx() && no_type_warning) {
1995 TORCH_WARN(
1996 "The shape inference of ",
1997 output->node()->kind().toDisplayString(),
1998 " type is missing, so it may result in wrong shape inference for the exported graph. ",
1999 "Please consider adding it in symbolic function.");
2000 // Experimental; nothing is sent to stdout or stderr.
2001 diagnostics::Diagnose(
2002 diagnostics::Rule::kNodeMissingOnnxShapeInference,
2003 diagnostics::Level::kWarning,
2004 {{"op_name", output->node()->kind().toDisplayString()}});
2005 }
2006 auto reliable = false;
2007 if (inferred) {
2008 reliable = inferred_type_reliable.first;
2009 } else {
2010 if (inferred_type_reliable.second && isTypeReliableForTracer) {
2011 reliable = true;
2012 }
2013 }
2014 // Assuming the tracer estimates rank correctly, the output tensor of
2015 // Shape should always be reliable.
2016 if (output->node()->kind() == ::c10::onnx::Shape) {
2017 reliable = true;
2018 }
2019 ConstantValueMap::SetTypeReliable(output->debugName(), reliable);
2020 if (!reliable) {
2021 if (auto output_tensor_type = output->type()->cast<TensorType>()) {
2022 output->setType(output_tensor_type->withSymbolicShapes(
2023 ::c10::SymbolicShape(output_tensor_type->dim())));
2024 }
2025 }
2026 }
2027
2028 void UpdateReliable(Node* n) {
2029 auto input_reliable = AreInputsReliableOrStatic(n);
2030 for (auto output : n->outputs()) {
2031 UpdateReliable(output, input_reliable);
2032 }
2033 }
2034
2035 // Traverse the graph inputs and compute reliability (e.g., are shapes static).
2036 // Since the inputs do not change during export, we save computation time by
2037 // marking the result as computed and skipping subsequent traversals.
2038 void SetGraphInputTypeReliable(const Graph* g) {
2039 if (!ConstantValueMap::GetAllGraphInputsReliableComputed()) {
2040 for (auto graph_input : g->inputs()) {
2041 if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) {
2042 ConstantValueMap::SetTypeReliable(graph_input->debugName(), true);
2043 }
2044 }
2045 ConstantValueMap::SetAllGraphInputsReliableComputed(true);
2046 }
2047 }
2048
2049 void ONNXShapeTypeInference(
2050 Node* n,
2051 const ParamMap& params_dict,
2052 int opset_version) {
2053 std::unordered_map<std::string, std::string> torch_to_onnx_input;
2054 std::unordered_map<std::string, std::string> torch_to_onnx_output;
2055 auto& original_shape_data = ConstantValueMap::GetInferredShapeData();
2056 ShapeDataMap inferred_shape_data;
2057 auto& symbol_dim_map = ConstantValueMap::GetSymbolDimMap();
2058 auto& dim_symbol_map = ConstantValueMap::GetDimSymbolMap();
2059
2060 SetGraphInputTypeReliable(n->owningGraph());
2061 GRAPH_UPDATE(
2062 "Running ONNX shape inference for node: ", n->kind().toDisplayString());
2063
2064 if (IsValidONNXNode(n)) {
2065 // Create a Graph containing only the single node n.
2066 // This graph is later converted to ONNX to run shape inference.
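// Illustrative example (hypothetical value names): for %y = onnx::Relu(%x),
// the cloned single-node graph is roughly
//   graph(%x.1): %y.1 = onnx::Relu(%x.1) return (%y.1)
// and ONNX shape inference runs on the exported proto of that graph.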
2067 auto n_graph = std::make_shared<Graph>();
2068 auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
2069 n_graph->insertNode(clone_node);
2070
2071 // Register all node outputs as graph outputs.
2072 for (auto output : clone_node->outputs()) {
2073 n_graph->registerOutput(output);
2074 }
2075
2076 // Map the original PyTorch graph's i/o names
2077 // to the temporary ONNX graph's i/o names for shape inference.
2078 for (size_t i = 0; i < clone_node->inputs().size(); ++i) {
2079 torch_to_onnx_input[n->input(i)->debugName()] =
2080 clone_node->input(i)->debugName();
2081 }
2082
2083 for (size_t i = 0; i < clone_node->outputs().size(); ++i) {
2084 torch_to_onnx_output[n->output(i)->debugName()] =
2085 clone_node->output(i)->debugName();
2086 }
2087 // Make inferred_shape_data use names from the temporary ONNX graph
2088 // instead of the original PyTorch graph. Only copy what we need,
2089 // which are the inputs of n.
2090 for (auto input : n->inputs()) {
2091 const auto maybe_shape = original_shape_data.find(input->debugName());
2092 if (maybe_shape != original_shape_data.end()) {
2093 const auto onnx_output_name =
2094 torch_to_onnx_input.find(input->debugName());
2095 if (onnx_output_name != torch_to_onnx_input.end()) {
2096 inferred_shape_data[onnx_output_name->second] = maybe_shape->second;
2097 }
2098 }
2099 }
2100 // Use scalar_type_analysis without low precision cast
2101 ScalarTypeAnalysisForONNX(n_graph, false, opset_version);
2102
2103 GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString());
2104 GRAPH_DEBUG(
2105 "Cloned torch graph to run shape inference: ", n_graph->toString());
2106
2107 if (IsGraphValidForInference(n_graph)) {
2108 // TODO: Some ops have conversion happen at Peephole pass.
2109 // The conversion here is incomplete for these ops.
2110 // e.g: ListConstruct, ListUnpack, etc.
2111 std::shared_ptr<onnx::ModelProto> model_proto;
2112 ConvertGraphToONNXProto(
2113 n_graph, model_proto, symbol_dim_map, dim_symbol_map, opset_version);
2114 GRAPH_DEBUG(
2115 "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
2116
2117 // infer shape
2118 try {
2119 // TODO(#79208): Enable more operators to support data propagation
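// For onnx::Shape and onnx::Gather, ONNX data propagation is enabled so
// that propagated shape values (not just ranks and dims) are collected
// into inferred_shape_data.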
2120 switch (n->kind()) {
2121 case ::c10::onnx::Shape:
2122 case ::c10::onnx::Gather: {
2123 auto* schema_registry = onnx::OpSchemaRegistry::Instance();
2124 onnx::ShapeInferenceOptions options{
2125 /*check_type_val=*/false,
2126 /*strict_mode_val=*/0,
2127 /*data_prop_val=*/true};
2128 onnx::shape_inference::InferShapes(
2129 *model_proto, schema_registry, options, &inferred_shape_data);
2130 break;
2131 }
2132 default: {
2133 onnx::shape_inference::InferShapes(*model_proto);
2134 break;
2135 }
2136 }
2137 UpdateOutputTypeByONNXProto(
2138 n, clone_node, *model_proto, symbol_dim_map, dim_symbol_map);
2139 } catch (std::runtime_error& ex) {
2140 // TODO: include this as warning once we have a more consolidated
2141 // warning system.
2142 GRAPH_DEBUG(
2143 "ONNX shape inference fails with: ",
2144 ex.what(),
2145 " on graph: ",
2146 n_graph->toString());
2147 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2148 const char shape_err[] = "ShapeInferenceError";
2149 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2150 const char type_err[] = "TypeInferenceError";
2151 if ((strstr(ex.what(), shape_err) == nullptr) &&
2152 (strstr(ex.what(), type_err) == nullptr)) {
2153 throw;
2154 }
2155 }
2156 GRAPH_DEBUG(
2157 "ONNX graph after shape inference: ", prettyPrint(*model_proto));
2158 }
2159 } else if (CustomSettype(n)) {
2160 // If the node is not a standard ONNX op, go through every output to check
2161 // whether they all have shapes. If they do, the types are treated as
2162 // reliable even though the op is not from ONNX.
2163 for (auto node_output : n->outputs()) {
2164 // Outputs from a custom setType should reach here if it is set correctly.
2165 // They are marked as inferred so the later UpdateReliable call treats them as such.
2166 ConstantValueMap::SetUseInferredType(node_output->debugName(), true);
2167 }
2168 }
2169
2170 SpecialPostProcess(n);
2171 // Get data propagation result from ONNX shape inference
2172 for (const auto& output : n->outputs()) {
2173 const auto inferred_shape_pair =
2174 inferred_shape_data.find(torch_to_onnx_output[output->debugName()]);
2175 if (inferred_shape_pair != inferred_shape_data.end()) {
2176 const auto& inferred_shape = inferred_shape_pair->second;
2177 int rank = inferred_shape.dim_size();
2178 std::vector<::c10::ShapeSymbol> final_shape(rank);
2179 for (int i = 0; i < rank; ++i) {
2180 final_shape[i] = ONNXDimToShapeSymbol(
2181 inferred_shape.dim(i), symbol_dim_map, dim_symbol_map);
2182 }
2183 c10::SymbolicShape shape_value(final_shape);
2184 // Store data propagation result into shapeValueMap
2185 ConstantValueMap::SetShapeValue(output->debugName(), shape_value);
2186 // Use the original name from the PyTorch graph instead of the
2187 // temporary name from the intermediate ONNX graph, and add the
2188 // result back to original_shape_data.
2189 original_shape_data[output->debugName()] = inferred_shape;
2190 }
2191 }
2192
2193 if (IsValidONNXNode(n)) {
2194 ProcessConstantValueMap(n, opset_version);
2195 if (n->kind() != prim::ListConstruct) {
2196 for (auto input : n->inputs()) {
2197 if (input->node()->kind() == prim::ListConstruct) {
2198 UpdateReliable(input, AreInputsReliableOrStatic(input->node()));
2199 }
2200 }
2201 }
2202 }
2203 UpdateReliable(n);
2204
2205 // Node kinds without ComputeConstant logic may have reliable shapes that
2206 // are not yet in ConstantValueMap, so update ConstantValueMap for their
2207 // outputs here.
2208 for (auto node_output : n->outputs()) {
2209 UpdateShapeConstantIfReliable(node_output);
2210 }
2211
2212 GRAPH_DEBUG(
2213 "Torch graph after shape inference:", n->owningGraph()->toString());
2214 }
2215
2216 void ONNXSetDynamicInputShape(
2217 std::shared_ptr<Graph>& graph,
2218 const std::unordered_map<
2219 std::string,
2220 std::unordered_map<int64_t, std::string>>& dynamic_axes,
2221 const std::vector<std::string>& input_names) {
2222 GRAPH_UPDATE("ONNX set dynamic input shape.");
2223 GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
2224 std::vector<std::string> res(dynamic_axes.size());
2225 std::transform(
2226 dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
2227 return pair.first;
2228 });
2229 return res;
2230 }());
2231
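// Illustrative example (hypothetical names): with
//   dynamic_axes = {"input": {0: "batch"}}
// dimension 0 of graph input "input" is replaced below by a fresh
// ShapeSymbol that is shared by every axis mapped to the name "batch".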
2232 std::map<std::string, ::c10::ShapeSymbol> name_to_sym;
2233
2234 for (const auto i : c10::irange(input_names.size())) {
2235 const auto& input_name = input_names[i];
2236 if (dynamic_axes.find(input_name) != dynamic_axes.end()) {
2237 auto axes_names = dynamic_axes.find(input_name)->second;
2238 TORCH_INTERNAL_ASSERT(i < graph->inputs().size());
2239 auto input_tensor_type = graph->inputs()[i]->type()->cast<TensorType>();
2240 if (!input_tensor_type) {
2241 continue;
2242 }
2243
2244 auto shape_ref = input_tensor_type->symbolic_sizes().sizes();
2245 TORCH_CHECK(
2246 shape_ref.has_value(), "Input tensor shape should have value.");
2247 auto shape = shape_ref.value();
2248
2249 for (const auto& pair : axes_names) {
2250 const auto axis = pair.first;
2251 const auto name = pair.second;
2252 if (name_to_sym.find(name) == name_to_sym.end()) {
2253 name_to_sym[name] = ::c10::ShapeSymbol::newSymbol();
2254 }
2255 TORCH_CHECK(
2256 axis < static_cast<int64_t>(shape.size()),
2257 "Dynamic shape axis should be no more than the shape dimension for ",
2258 name);
2259 shape[axis] = name_to_sym[name];
2260 }
2261
2262 graph->inputs()[i]->setType(
2263 input_tensor_type->withSymbolicShapes(::c10::SymbolicShape(shape)));
2264 }
2265 }
2266 }
2267
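// Returns true for ONNX ops whose output may be a Sequence type, including
// Loop and If, whose outputs can carry sequences produced in subgraphs.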
2268 bool HasSequenceTypeOutput(Node* node) {
2269 if (node->kind() == ::c10::onnx::SplitToSequence ||
2270 node->kind() == ::c10::onnx::SequenceInsert ||
2271 node->kind() == ::c10::onnx::SequenceEmpty ||
2272 node->kind() == ::c10::onnx::SequenceErase ||
2273 node->kind() == ::c10::onnx::SequenceConstruct ||
2274 node->kind() == ::c10::onnx::Loop || node->kind() == ::c10::onnx::If)
2275 return true;
2276 return false;
2277 }
2278
2279 void ONNXUpdateTypeFromTensor(
2280 Value* graph_output,
2281 const at::Tensor& output,
2282 bool onnx_shape_inference) {
2283 if (onnx_shape_inference) {
2284 MergeInferredTypeAndSetMap(
2285 graph_output, TensorType::create(output), graph_output->type());
2286 } else {
2287 graph_output->inferTypeFrom(output);
2288 }
2289 }
2290
2291 // Recursively look into elements in `output_obj`, and assign shape/type info
2292 // into flattened graph outputs. `outputs_index` is passed in to point to the
2293 // current index in flattened graph outputs. The updated `outputs_index` is
2294 // returned at the end of the function.
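// Illustrative example (hypothetical outputs): for example outputs
// (tensor_a, [tensor_b, tensor_c]) where the list is not a sequence-typed
// graph output, the recursion visits tensor_a, tensor_b, and tensor_c in
// order, advancing outputs_index past each flattened graph output.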
2295 size_t ONNXAssignOutputShape(
2296 std::shared_ptr<Graph>& graph,
2297 size_t outputs_index,
2298 PyObject* output_obj,
2299 bool onnx_shape_inference,
2300 bool is_script,
2301 int opset_version) {
2302 auto index_check = [&]() {
2303 TORCH_INTERNAL_ASSERT(
2304 outputs_index <= graph->outputs().size(),
2305 "Incorrect number of elements provided as example outputs.");
2306 };
2307
2308 index_check();
2309
2310 if (THPVariable_Check(output_obj)) {
2311 const at::Tensor& var = THPVariable_Unpack(output_obj);
2312 ONNXUpdateTypeFromTensor(
2313 graph->outputs().at(outputs_index), var, onnx_shape_inference);
2314 outputs_index++;
2315 } else if (PyTuple_Check(output_obj)) {
2316 size_t tuple_len = PyTuple_GET_SIZE(output_obj);
2317 for (const auto i : c10::irange(tuple_len)) {
2318 outputs_index = ONNXAssignOutputShape(
2319 graph,
2320 outputs_index,
2321 PyTuple_GET_ITEM(output_obj, i),
2322 onnx_shape_inference,
2323 is_script,
2324 opset_version);
2325 }
2326 } else if (PyList_Check(output_obj)) {
2327 const auto list_len = PyList_GET_SIZE(output_obj);
2328 if (HasSequenceTypeOutput(graph->outputs().at(outputs_index)->node())) {
2329 auto output_type = graph->outputs().at(outputs_index)->type();
2330 TORCH_CHECK(
2331 output_type->cast<ListType>(),
2332 "Expected a sequence type, but received a non-iterable type in graph output index ",
2333 outputs_index);
2334 if (list_len > 0) {
2335 auto list_elem = PyList_GET_ITEM(output_obj, 0);
2336 TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
2337 auto& var = THPVariable_Unpack(list_elem);
2338 for (const auto i : c10::irange(1, list_len)) {
2339 list_elem = PyList_GET_ITEM(output_obj, i);
2340 TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
2341 auto& new_var = THPVariable_Unpack(list_elem);
2342 TORCH_CHECK(
2343 var.scalar_type() == new_var.scalar_type(),
2344 "Unsupported sequence with mixed element types in model outputs. "
2345 "ONNX supports only sequences of elements of the same data type.");
2346 }
2347 auto elem_type = graph->outputs()
2348 .at(outputs_index)
2349 ->type()
2350 ->castRaw<ListType>()
2351 ->getElementType()
2352 ->cast<TensorType>();
2353 elem_type = elem_type->withScalarType(var.scalar_type());
2354 auto graph_output = graph->outputs().at(outputs_index);
2355 MergeInferredTypeAndSetMap(
2356 graph_output, graph_output->type(), ListType::create(elem_type));
2357 } else {
2358 graph->outputs()
2359 .at(outputs_index)
2360 ->setType(graph->outputs().at(outputs_index)->type());
2361 }
2362 outputs_index++;
2363 } else {
2364 // The torch output is a list type, but the ONNX node does not produce
2365 // a sequence type, e.g. prim::ListConstruct.
2366 for (const auto i : c10::irange(list_len)) {
2367 outputs_index = ONNXAssignOutputShape(
2368 graph,
2369 outputs_index,
2370 PyList_GET_ITEM(output_obj, i),
2371 onnx_shape_inference,
2372 is_script,
2373 opset_version);
2374 }
2375 }
2376 } else if (PyDict_Check(output_obj)) {
2377 // Support for dict data type is limited to fixed size dictionaries in
2378 // ONNX.
2379 // Dictionary values are unrolled and keys are not preserved.
2380 auto* items = PyDict_Items(output_obj);
2381 auto unrolled_dict = py::reinterpret_borrow<py::list>(items);
2382 TORCH_INTERNAL_ASSERT(PyList_Check(unrolled_dict.ptr()));
2383 for (const auto i : c10::irange(unrolled_dict.size())) {
2384 outputs_index = ONNXAssignOutputShape(
2385 graph,
2386 outputs_index,
2387 PyList_GET_ITEM(unrolled_dict.ptr(), i),
2388 onnx_shape_inference,
2389 is_script,
2390 opset_version);
2391 }
2392 Py_DECREF(items);
2393 } else if (THPUtils_checkString(output_obj)) {
2394 // Ignore string, since they are not supported as output in ONNX.
2395 } else if (PyNone_Check(output_obj)) {
2396 // Tracing:
2397 // Ignore None, since it is not captured in IR graph as output.
2398 // Scripting:
2399 // Ignore None, if observing a fixed `None` node in IR graph. Because
2400 // it is meaningless to include it as graph output as it carries no
2401 // data/information. Plus that static `None` is not supported in ONNX
2402 // IR. Otherwise, the output should have type `Optional`, and should be
2403 // converted to ONNX `Optional`.
2404
2405 // More context:
2406 // Cause: during tracing, the outputs are flattened in
2407 // ONNXTracedModule.forward in torch/jit/_trace.py, so the traced IR
2408 // graph has None outputs omitted.
2409 // But then the outputs passed in here are un-flattened, which means they
2410 // contain None objects. Ideally we'd remove this difference.
2411 if (is_script && outputs_index < graph->outputs().size()) {
2412 if (graph->outputs().at(outputs_index)->node()->mustBeNone()) {
2413 if (opset_version >= 15) {
2414 ReplaceGraphOutputNoneWithOptional(graph, outputs_index);
2415 outputs_index++;
2416 } else {
2417 graph->eraseOutput(outputs_index);
2418 }
2419 } else {
2420 outputs_index++;
2421 }
2422 }
2423 } else {
2424 std::string msg =
2425 ("Model output has unsupported type. See "
2426 "https://pytorch.org/docs/stable/onnx.html#types. Got type: ");
2427 msg += THPUtils_typename(output_obj);
2428 throw std::runtime_error(msg);
2429 }
2430
2431 index_check();
2432
2433 return outputs_index;
2434 }
2435
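// Creates an onnx::Optional node standing in for a `None` graph output
// (used for opset >= 15); the element type defaults to a Float tensor.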
2436 Node* ONNXOptionalNodeForNone(std::shared_ptr<Graph>& graph) {
2437 TypePtr elem_type = TensorType::get()->withScalarType(at::ScalarType::Float);
2438 Node* opt_node = graph->create(::c10::onnx::Optional, 1);
2439 opt_node->ty_(Symbol::attr("type"), elem_type);
2440 opt_node->output()->setType(OptionalType::create(elem_type));
2441 return opt_node;
2442 }
2443
2444 void ReplaceGraphOutputNoneWithOptional(
2445 std::shared_ptr<Graph>& graph,
2446 size_t outputs_index) {
2447 Node* opt_node = ONNXOptionalNodeForNone(graph);
2448 opt_node->insertBefore(graph->return_node());
2449 Value* graph_output = graph->outputs().at(outputs_index);
2450 // Replace only the last value, since the Optional type only affects
2451 // the value right before the output.
2452 graph_output->replaceAllUsesAfterNodeWith(opt_node, opt_node->output());
2453 if (!graph_output->type()->cast<NoneType>()) {
2454 opt_node->addInput(graph_output);
2455 opt_node->copyMetadata(graph_output->node());
2456 }
2457 }
2458
2459 void ONNXAssignOutputShape(
2460 std::shared_ptr<Graph>& graph,
2461 at::ArrayRef<at::Tensor> outputs,
2462 const python::IODescriptor& desc,
2463 bool onnx_shape_inference,
2464 bool is_script,
2465 int opset_version) {
2466 size_t outputs_index = 0;
2467 PyObject* py_obj = unflatten(outputs, desc);
2468 TORCH_INTERNAL_ASSERT(PyTuple_Check(py_obj));
2469
2470 outputs_index = ONNXAssignOutputShape(
2471 graph,
2472 outputs_index,
2473 py_obj,
2474 onnx_shape_inference,
2475 is_script,
2476 opset_version);
2477
2478 TORCH_INTERNAL_ASSERT(
2479 outputs_index == graph->outputs().size(),
2480 "Incorrect number of elements provided as example outputs.");
2481
2482 Py_DECREF(py_obj);
2483 GRAPH_DUMP("After ONNXAssignOutputShape", graph);
2484 }
2485
2486 void ONNXShapeTypeInference(
2487 std::shared_ptr<Graph>& graph,
2488 const ParamMap& params_dict,
2489 int opset_version) {
2490 ConstantValueMap::ClearMaps();
2491 SetGraphInputTypeReliable(graph.get());
2492 ONNXShapeTypeInference(graph->block(), params_dict, opset_version);
2493 ConstantValueMap::ClearMaps();
2494 }
2495
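// If a value's type is marked reliable and its shape is not yet cached,
// record its symbolic sizes in ConstantValueMap.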
2496 void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) {
2497 if (ConstantValueMap::HasTypeReliable(node_output->debugName())) {
2498 auto reliable = ConstantValueMap::GetTypeReliable(node_output->debugName())
2499 .value_or(false);
2500 if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) {
2501 // TODO: ListType case
2502 if (auto output_tensor_type = node_output->type()->cast<TensorType>()) {
2503 if (output_tensor_type->dim()) {
2504 auto symbolic_sizes = output_tensor_type->symbolic_sizes();
2505 UpdateShapeConstantValueMap(node_output, symbolic_sizes);
2506 }
2507 }
2508 }
2509 }
2510 }
2511
2512 } // namespace torch::jit
2513