xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/helper.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/helper.h>
3 #include <torch/csrc/onnx/back_compat.h>
4 
5 #include <ATen/ScalarOps.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #else
10 #include <ATen/ops/unsqueeze.h>
11 #endif
12 
13 #include <onnx/onnx_pb.h>
14 
15 namespace torch::jit {
16 namespace onnx {
17 using namespace ::c10::onnx;
18 
19 } // namespace onnx
20 
buildValueToParamsMap(Block * b,const ParamMap & paramsDict)21 ValueToParamPairMap buildValueToParamsMap(
22     Block* b,
23     const ParamMap& paramsDict) {
24   ValueToParamPairMap valsToParamsMap;
25   for (auto& input : b->inputs()) {
26     auto it = paramsDict.find(input->debugName());
27     if (it != paramsDict.end()) {
28       valsToParamsMap.emplace(input, *it);
29     }
30   }
31   return valsToParamsMap;
32 }
33 
eraseUnusedBlockInputs(Block * b)34 void eraseUnusedBlockInputs(Block* b) {
35   for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
36     size_t i = i_1 - 1;
37     if (!b->inputs().at(i)->hasUses()) {
38       b->eraseInput(i);
39     }
40   }
41 }
42 
eraseUnusedValuesFromMap(ValueToParamPairMap & valsToParamsMap)43 void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
44   auto it = valsToParamsMap.begin();
45   while (it != valsToParamsMap.end()) {
46     if (!it->first->hasUses()) {
47       it = valsToParamsMap.erase(it);
48     } else {
49       ++it;
50     }
51   }
52 }
53 
buildParamsMapFromValueToParamsMap(const ValueToParamPairMap & valsToParamsMap,ParamMap & paramsDict)54 void buildParamsMapFromValueToParamsMap(
55     const ValueToParamPairMap& valsToParamsMap,
56     ParamMap& paramsDict) {
57   paramsDict.clear();
58   for (const auto& nameTensorParamPair : valsToParamsMap) {
59     paramsDict.insert(nameTensorParamPair.second);
60   }
61 }
62 
ONNXTypeToATenType(int32_t onnx_type)63 std::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type) {
64   switch (onnx_type) {
65     case ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
66       return at::ScalarType::Undefined;
67     case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
68       return at::kFloat;
69     case ::ONNX_NAMESPACE::TensorProto_DataType_UINT8:
70       return at::kByte;
71     case ::ONNX_NAMESPACE::TensorProto_DataType_INT8:
72       return at::kChar;
73     case ::ONNX_NAMESPACE::TensorProto_DataType_INT16:
74       return at::kShort;
75     case ::ONNX_NAMESPACE::TensorProto_DataType_INT32:
76       return at::kInt;
77     case ::ONNX_NAMESPACE::TensorProto_DataType_INT64:
78       return at::kLong;
79     case ::ONNX_NAMESPACE::TensorProto_DataType_BOOL:
80       return at::kBool;
81     case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
82       return at::kHalf;
83     case ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
84       return at::kDouble;
85     case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64:
86       return at::kComplexFloat;
87     case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128:
88       return at::kComplexDouble;
89     case ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
90       return at::kBFloat16;
91     case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2:
92       return at::kFloat8_e5m2;
93     case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ:
94       return at::kFloat8_e5m2fnuz;
95     case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN:
96       return at::kFloat8_e4m3fn;
97     case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ:
98       return at::kFloat8_e4m3fnuz;
99     default:
100       TORCH_CHECK(
101           false,
102           "ONNX type ",
103           onnx_type,
104           " is an unexpected tensor scalar type");
105   }
106   return std::optional<at::ScalarType>{};
107 }
108 
addNodeToBlock(Block * block,Symbol kind,ArrayRef<Value * > inputs)109 Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef<Value*> inputs) {
110   auto new_node = block->appendNode(block->owningGraph()->create(kind));
111   for (auto input : inputs) {
112     new_node->addInput(input);
113   }
114   return new_node;
115 }
116 
addInputToBlock(Block * block)117 Value* addInputToBlock(Block* block) {
118   return block->addInput();
119 }
120 
121 namespace {
ATenTypeToOnnxType_aux(at::ScalarType at_type)122 ::ONNX_NAMESPACE::TensorProto_DataType ATenTypeToOnnxType_aux(
123     at::ScalarType at_type) {
124   switch (at_type) {
125     case at::kDouble:
126       return ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
127     case at::kFloat:
128       return ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
129     case at::kHalf:
130       return ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
131     case at::kByte:
132       return ::ONNX_NAMESPACE::TensorProto_DataType_UINT8;
133     case at::kChar:
134       return ::ONNX_NAMESPACE::TensorProto_DataType_INT8;
135     case at::kShort:
136       return ::ONNX_NAMESPACE::TensorProto_DataType_INT16;
137     case at::kInt:
138       return ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
139     case at::kLong:
140       return ::ONNX_NAMESPACE::TensorProto_DataType_INT64;
141     case at::kBool:
142       return ::ONNX_NAMESPACE::TensorProto_DataType_BOOL;
143     case at::kQInt8:
144       return ::ONNX_NAMESPACE::TensorProto_DataType_INT8;
145     case at::kQUInt8:
146       return ::ONNX_NAMESPACE::TensorProto_DataType_UINT8;
147     case at::kQInt32:
148       return ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
149     default:
150       TORCH_CHECK(
151           false,
152           "ScalarType ",
153           toString(at_type),
154           " is an unexpected tensor scalar type");
155   }
156 }
157 } // namespace
158 
ATenTypeToOnnxType(at::ScalarType at_type)159 int ATenTypeToOnnxType(at::ScalarType at_type) {
160   return static_cast<int>(ATenTypeToOnnxType_aux(at_type));
161 }
162 
createONNXUnsqueeze(Graph * graph,Node * n_to_insert_before,Value * input,int axis,int opset_version)163 Node* createONNXUnsqueeze(
164     Graph* graph,
165     Node* n_to_insert_before,
166     Value* input,
167     int axis,
168     int opset_version) {
169   Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1);
170   unsqueeze_node->addInput(input);
171   unsqueeze_node->insertBefore(n_to_insert_before);
172   if (opset_version >= OPSET_VERSION_13) {
173     // ONNX spec sets `axes` as input for opset >= 13.
174     Node* unsqueeze_axes = graph->create(onnx::Constant, 1);
175     unsqueeze_axes->insertBefore(unsqueeze_node);
176     unsqueeze_axes->t_(
177         attr::value, at::unsqueeze(at::scalar_to_tensor(at::Scalar(axis)), 0));
178     unsqueeze_node->addInput(unsqueeze_axes->output());
179   } else {
180     // ONNX spec sets `axes` as attribute for opset < 13.
181     unsqueeze_node->is_(attr::axes, {0});
182   }
183   return unsqueeze_node;
184 }
185 
createONNXConstant(Graph * graph,Node * n_to_insert_before,at::Tensor value)186 Node* createONNXConstant(
187     Graph* graph,
188     Node* n_to_insert_before,
189     at::Tensor value) {
190   Node* constant_node = graph->create(onnx::Constant, 1);
191   constant_node->insertBefore(n_to_insert_before);
192   constant_node->t_(attr::value, std::move(value));
193   return constant_node;
194 }
195 
isValidToTransformToONNXConcatNode(Node * lc_node)196 bool isValidToTransformToONNXConcatNode(Node* lc_node) {
197   return !lc_node->inputs().empty();
198 }
199 
transformToONNXConcatNode(Graph * g,Node * lc_node,bool need_new_input,int opset_version)200 Node* transformToONNXConcatNode(
201     Graph* g,
202     Node* lc_node,
203     bool need_new_input,
204     int opset_version) {
205   // ListConstruct Int[] output case, we need to transform to ONNX
206   // Concat to ensure the output is a single tensor(dynamic) type in
207   // order to be consumed as inputs
208   std::vector<Value*> unsqueezed;
209   auto new_node = need_new_input ? g->return_node() : lc_node;
210 
211   for (auto* input : lc_node->inputs()) {
212     auto new_input =
213         need_new_input ? g->addInput()->copyMetadata(input) : input;
214     // This particular Concat operation concats along axis=0 and this requires
215     // inputs to the node to have the same shape along dim-0. To ensure this,
216     // unsqueeze nodes are added such that all shapes along dim-0 are 1.
217     // Certain inputs from ListConstruct Int[] could be combinations of scalars
218     // and 1-D tensors, For inputs that are already 1-D tensors, we skip the
219     // step of creating a corresponding unsqueeze node.
220     if (auto type = new_input->type()->cast<TensorType>()) {
221       if (type->dim() && type->dim() == 1U) {
222         unsqueezed.emplace_back(new_input);
223         continue;
224       }
225     }
226     Node* unsqueezed_node =
227         createONNXUnsqueeze(g, new_node, new_input, 0, opset_version);
228     unsqueezed_node->copyMetadata(lc_node);
229     unsqueezed.emplace_back(unsqueezed_node->output());
230   }
231 
232   Node* concat_node = need_new_input
233       ? g->insertNode(g->create(onnx::Concat, 1))
234       : g->create(onnx::Concat, 1)->insertBefore(lc_node);
235   concat_node->i_(attr::axis, 0);
236   for (auto v : unsqueezed) {
237     concat_node->addInput(v);
238   }
239 
240   return concat_node;
241 }
242 
ONNXLintGraph(const Block * b,std::vector<NodeKind> & n_miss_source_range,std::vector<NodeKind> & n_miss_scope)243 void ONNXLintGraph(
244     const Block* b,
245     std::vector<NodeKind>& n_miss_source_range,
246     std::vector<NodeKind>& n_miss_scope) {
247   for (const auto* n : b->nodes()) {
248     for (const auto* sub_b : n->blocks()) {
249       ONNXLintGraph(sub_b, n_miss_source_range, n_miss_scope);
250     }
251 
252     if (nullptr == n->sourceRange().source()) {
253       GRAPH_DEBUG("Node does not set sourceRange:", *n);
254       n_miss_source_range.emplace_back(n->kind());
255     }
256     if (n->scopeName().empty()) {
257       GRAPH_DEBUG("Node does not set scope:", *n);
258       n_miss_scope.emplace_back(n->kind());
259     }
260   }
261 }
262 
ONNXLintGraph(const std::shared_ptr<Graph> & graph)263 void ONNXLintGraph(const std::shared_ptr<Graph>& graph) {
264   // Print nodes that do not have scope/source range covered.
265   std::vector<NodeKind> n_miss_source_range, n_miss_scope;
266   ONNXLintGraph(graph->block(), n_miss_source_range, n_miss_scope);
267   auto count_const = [](const std::vector<NodeKind>& vec) -> size_t {
268     size_t count = 0;
269     for (auto k : vec) {
270       switch (k) {
271         case prim::Constant:
272         case prim::ListConstruct:
273         case onnx::Constant:
274           count++;
275           break;
276       }
277     }
278     return count;
279   };
280   auto const_count_src = count_const(n_miss_source_range);
281   auto const_count_scope = count_const(n_miss_scope);
282   GRAPH_UPDATE(
283       "Missing source range.\n",
284       "Total ",
285       n_miss_source_range.size(),
286       " nodes. Including ",
287       const_count_src,
288       " constants.");
289   GRAPH_UPDATE(
290       "Missing scope.\n",
291       "Total ",
292       n_miss_scope.size(),
293       " nodes. Including ",
294       const_count_scope,
295       " constants.");
296 }
297 
298 } // namespace torch::jit
299