1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/helper.h>
3 #include <torch/csrc/onnx/back_compat.h>
4
5 #include <ATen/ScalarOps.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #else
10 #include <ATen/ops/unsqueeze.h>
11 #endif
12
13 #include <onnx/onnx_pb.h>
14
15 namespace torch::jit {
16 namespace onnx {
17 using namespace ::c10::onnx;
18
19 } // namespace onnx
20
buildValueToParamsMap(Block * b,const ParamMap & paramsDict)21 ValueToParamPairMap buildValueToParamsMap(
22 Block* b,
23 const ParamMap& paramsDict) {
24 ValueToParamPairMap valsToParamsMap;
25 for (auto& input : b->inputs()) {
26 auto it = paramsDict.find(input->debugName());
27 if (it != paramsDict.end()) {
28 valsToParamsMap.emplace(input, *it);
29 }
30 }
31 return valsToParamsMap;
32 }
33
eraseUnusedBlockInputs(Block * b)34 void eraseUnusedBlockInputs(Block* b) {
35 for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
36 size_t i = i_1 - 1;
37 if (!b->inputs().at(i)->hasUses()) {
38 b->eraseInput(i);
39 }
40 }
41 }
42
eraseUnusedValuesFromMap(ValueToParamPairMap & valsToParamsMap)43 void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
44 auto it = valsToParamsMap.begin();
45 while (it != valsToParamsMap.end()) {
46 if (!it->first->hasUses()) {
47 it = valsToParamsMap.erase(it);
48 } else {
49 ++it;
50 }
51 }
52 }
53
buildParamsMapFromValueToParamsMap(const ValueToParamPairMap & valsToParamsMap,ParamMap & paramsDict)54 void buildParamsMapFromValueToParamsMap(
55 const ValueToParamPairMap& valsToParamsMap,
56 ParamMap& paramsDict) {
57 paramsDict.clear();
58 for (const auto& nameTensorParamPair : valsToParamsMap) {
59 paramsDict.insert(nameTensorParamPair.second);
60 }
61 }
62
ONNXTypeToATenType(int32_t onnx_type)63 std::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type) {
64 switch (onnx_type) {
65 case ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
66 return at::ScalarType::Undefined;
67 case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
68 return at::kFloat;
69 case ::ONNX_NAMESPACE::TensorProto_DataType_UINT8:
70 return at::kByte;
71 case ::ONNX_NAMESPACE::TensorProto_DataType_INT8:
72 return at::kChar;
73 case ::ONNX_NAMESPACE::TensorProto_DataType_INT16:
74 return at::kShort;
75 case ::ONNX_NAMESPACE::TensorProto_DataType_INT32:
76 return at::kInt;
77 case ::ONNX_NAMESPACE::TensorProto_DataType_INT64:
78 return at::kLong;
79 case ::ONNX_NAMESPACE::TensorProto_DataType_BOOL:
80 return at::kBool;
81 case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
82 return at::kHalf;
83 case ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
84 return at::kDouble;
85 case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64:
86 return at::kComplexFloat;
87 case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128:
88 return at::kComplexDouble;
89 case ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
90 return at::kBFloat16;
91 case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2:
92 return at::kFloat8_e5m2;
93 case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ:
94 return at::kFloat8_e5m2fnuz;
95 case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN:
96 return at::kFloat8_e4m3fn;
97 case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ:
98 return at::kFloat8_e4m3fnuz;
99 default:
100 TORCH_CHECK(
101 false,
102 "ONNX type ",
103 onnx_type,
104 " is an unexpected tensor scalar type");
105 }
106 return std::optional<at::ScalarType>{};
107 }
108
addNodeToBlock(Block * block,Symbol kind,ArrayRef<Value * > inputs)109 Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef<Value*> inputs) {
110 auto new_node = block->appendNode(block->owningGraph()->create(kind));
111 for (auto input : inputs) {
112 new_node->addInput(input);
113 }
114 return new_node;
115 }
116
addInputToBlock(Block * block)117 Value* addInputToBlock(Block* block) {
118 return block->addInput();
119 }
120
121 namespace {
ATenTypeToOnnxType_aux(at::ScalarType at_type)122 ::ONNX_NAMESPACE::TensorProto_DataType ATenTypeToOnnxType_aux(
123 at::ScalarType at_type) {
124 switch (at_type) {
125 case at::kDouble:
126 return ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
127 case at::kFloat:
128 return ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
129 case at::kHalf:
130 return ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
131 case at::kByte:
132 return ::ONNX_NAMESPACE::TensorProto_DataType_UINT8;
133 case at::kChar:
134 return ::ONNX_NAMESPACE::TensorProto_DataType_INT8;
135 case at::kShort:
136 return ::ONNX_NAMESPACE::TensorProto_DataType_INT16;
137 case at::kInt:
138 return ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
139 case at::kLong:
140 return ::ONNX_NAMESPACE::TensorProto_DataType_INT64;
141 case at::kBool:
142 return ::ONNX_NAMESPACE::TensorProto_DataType_BOOL;
143 case at::kQInt8:
144 return ::ONNX_NAMESPACE::TensorProto_DataType_INT8;
145 case at::kQUInt8:
146 return ::ONNX_NAMESPACE::TensorProto_DataType_UINT8;
147 case at::kQInt32:
148 return ::ONNX_NAMESPACE::TensorProto_DataType_INT32;
149 default:
150 TORCH_CHECK(
151 false,
152 "ScalarType ",
153 toString(at_type),
154 " is an unexpected tensor scalar type");
155 }
156 }
157 } // namespace
158
ATenTypeToOnnxType(at::ScalarType at_type)159 int ATenTypeToOnnxType(at::ScalarType at_type) {
160 return static_cast<int>(ATenTypeToOnnxType_aux(at_type));
161 }
162
createONNXUnsqueeze(Graph * graph,Node * n_to_insert_before,Value * input,int axis,int opset_version)163 Node* createONNXUnsqueeze(
164 Graph* graph,
165 Node* n_to_insert_before,
166 Value* input,
167 int axis,
168 int opset_version) {
169 Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1);
170 unsqueeze_node->addInput(input);
171 unsqueeze_node->insertBefore(n_to_insert_before);
172 if (opset_version >= OPSET_VERSION_13) {
173 // ONNX spec sets `axes` as input for opset >= 13.
174 Node* unsqueeze_axes = graph->create(onnx::Constant, 1);
175 unsqueeze_axes->insertBefore(unsqueeze_node);
176 unsqueeze_axes->t_(
177 attr::value, at::unsqueeze(at::scalar_to_tensor(at::Scalar(axis)), 0));
178 unsqueeze_node->addInput(unsqueeze_axes->output());
179 } else {
180 // ONNX spec sets `axes` as attribute for opset < 13.
181 unsqueeze_node->is_(attr::axes, {0});
182 }
183 return unsqueeze_node;
184 }
185
createONNXConstant(Graph * graph,Node * n_to_insert_before,at::Tensor value)186 Node* createONNXConstant(
187 Graph* graph,
188 Node* n_to_insert_before,
189 at::Tensor value) {
190 Node* constant_node = graph->create(onnx::Constant, 1);
191 constant_node->insertBefore(n_to_insert_before);
192 constant_node->t_(attr::value, std::move(value));
193 return constant_node;
194 }
195
isValidToTransformToONNXConcatNode(Node * lc_node)196 bool isValidToTransformToONNXConcatNode(Node* lc_node) {
197 return !lc_node->inputs().empty();
198 }
199
// Rewrites a list-construction node (`lc_node`, e.g. producing Int[]) into
// an onnx::Concat along axis 0, so its output is a single (dynamic) tensor
// consumable by downstream ONNX ops.
//
// When `need_new_input` is true, each of `lc_node`'s inputs is mirrored by
// a fresh graph input (metadata copied from the original) and new nodes
// are inserted at the graph's current insertion point / before the return
// node; otherwise the original inputs are reused and nodes are inserted
// before `lc_node`. `opset_version` is forwarded to createONNXUnsqueeze to
// pick the axes-as-input vs axes-as-attribute form.
//
// Returns the new Concat node. NOTE(review): `lc_node` itself is not
// destroyed or replaced here — presumably the caller rewires its uses and
// erases it; confirm at call sites.
Node* transformToONNXConcatNode(
    Graph* g,
    Node* lc_node,
    bool need_new_input,
    int opset_version) {
  // ListConstruct Int[] output case, we need to transform to ONNX
  // Concat to ensure the output is a single tensor(dynamic) type in
  // order to be consumed as inputs
  std::vector<Value*> unsqueezed;
  auto new_node = need_new_input ? g->return_node() : lc_node;

  for (auto* input : lc_node->inputs()) {
    auto new_input =
        need_new_input ? g->addInput()->copyMetadata(input) : input;
    // This particular Concat operation concats along axis=0 and this requires
    // inputs to the node to have the same shape along dim-0. To ensure this,
    // unsqueeze nodes are added such that all shapes along dim-0 are 1.
    // Certain inputs from ListConstruct Int[] could be combinations of scalars
    // and 1-D tensors, For inputs that are already 1-D tensors, we skip the
    // step of creating a corresponding unsqueeze node.
    if (auto type = new_input->type()->cast<TensorType>()) {
      if (type->dim() && type->dim() == 1U) {
        unsqueezed.emplace_back(new_input);
        continue;
      }
    }
    Node* unsqueezed_node =
        createONNXUnsqueeze(g, new_node, new_input, 0, opset_version);
    unsqueezed_node->copyMetadata(lc_node);
    unsqueezed.emplace_back(unsqueezed_node->output());
  }

  // Insertion point differs by mode: at the graph's insertion point when
  // building against new inputs, otherwise directly before lc_node.
  Node* concat_node = need_new_input
      ? g->insertNode(g->create(onnx::Concat, 1))
      : g->create(onnx::Concat, 1)->insertBefore(lc_node);
  concat_node->i_(attr::axis, 0);
  for (auto v : unsqueezed) {
    concat_node->addInput(v);
  }

  return concat_node;
}
242
ONNXLintGraph(const Block * b,std::vector<NodeKind> & n_miss_source_range,std::vector<NodeKind> & n_miss_scope)243 void ONNXLintGraph(
244 const Block* b,
245 std::vector<NodeKind>& n_miss_source_range,
246 std::vector<NodeKind>& n_miss_scope) {
247 for (const auto* n : b->nodes()) {
248 for (const auto* sub_b : n->blocks()) {
249 ONNXLintGraph(sub_b, n_miss_source_range, n_miss_scope);
250 }
251
252 if (nullptr == n->sourceRange().source()) {
253 GRAPH_DEBUG("Node does not set sourceRange:", *n);
254 n_miss_source_range.emplace_back(n->kind());
255 }
256 if (n->scopeName().empty()) {
257 GRAPH_DEBUG("Node does not set scope:", *n);
258 n_miss_scope.emplace_back(n->kind());
259 }
260 }
261 }
262
ONNXLintGraph(const std::shared_ptr<Graph> & graph)263 void ONNXLintGraph(const std::shared_ptr<Graph>& graph) {
264 // Print nodes that do not have scope/source range covered.
265 std::vector<NodeKind> n_miss_source_range, n_miss_scope;
266 ONNXLintGraph(graph->block(), n_miss_source_range, n_miss_scope);
267 auto count_const = [](const std::vector<NodeKind>& vec) -> size_t {
268 size_t count = 0;
269 for (auto k : vec) {
270 switch (k) {
271 case prim::Constant:
272 case prim::ListConstruct:
273 case onnx::Constant:
274 count++;
275 break;
276 }
277 }
278 return count;
279 };
280 auto const_count_src = count_const(n_miss_source_range);
281 auto const_count_scope = count_const(n_miss_scope);
282 GRAPH_UPDATE(
283 "Missing source range.\n",
284 "Total ",
285 n_miss_source_range.size(),
286 " nodes. Including ",
287 const_count_src,
288 " constants.");
289 GRAPH_UPDATE(
290 "Missing scope.\n",
291 "Total ",
292 n_miss_scope.size(),
293 " nodes. Including ",
294 const_count_scope,
295 " constants.");
296 }
297
298 } // namespace torch::jit
299