#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

#include <ATen/Functions.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <numeric>
#include <optional>
11
12 namespace torch::jit {
13
// Pull the onnx symbol namespace (onnx::Slice, onnx::Concat, ...) into
// torch::jit::onnx so node kinds below can be spelled without the c10 prefix.
namespace onnx {
using namespace ::c10::onnx;
}
17
18 namespace onnx_constant_fold {
19
// Numeric type tags matching the ONNX TensorProto.DataType numbering.
// Only the types this pass converts are listed; the jump from
// ONNX_INT64 (7) to ONNX_FLOAT16 (10) skips the tags 8 and 9
// (string/bool in the ONNX numbering), which are not handled here.
enum OnnxType : int {
  ONNX_FLOAT = 1,
  ONNX_UINT8,
  ONNX_INT8,
  ONNX_UINT16,
  ONNX_INT16,
  ONNX_INT32,
  ONNX_INT64,
  ONNX_FLOAT16 = 10,
  ONNX_DOUBLE,
  ONNX_UINT32,
};
32
// Maps an OnnxType tag to the ATen ScalarType used when materializing
// folded tensors.
// NOTE(review): this global is non-const; nothing in this file mutates
// it, so it could likely be made const — confirm no external writers.
std::unordered_map<int, at::ScalarType> onnxTypeToScalarTypeMap = {
    // Only conversion of ONNX numeric types is included here.
    // Unsigned ONNX types are mapped to the next higher signed
    // ScalarType type (e.g. UINT16 -> kInt, UINT32 -> kLong), and
    // FLOAT16 is widened to kFloat.
    {ONNX_FLOAT, at::kFloat},
    {ONNX_UINT8, at::kByte},
    {ONNX_INT8, at::kChar},
    {ONNX_UINT16, at::kInt},
    {ONNX_INT16, at::kShort},
    {ONNX_INT32, at::kInt},
    {ONNX_INT64, at::kLong},
    {ONNX_FLOAT16, at::kFloat},
    {ONNX_DOUBLE, at::kDouble},
    {ONNX_UINT32, at::kLong},
};
48
handleNegativeStartEndIndex(int64_t & start,int64_t & end,int64_t & axis,c10::IntArrayRef tensorSizes)49 void handleNegativeStartEndIndex(
50 int64_t& start,
51 int64_t& end,
52 int64_t& axis,
53 c10::IntArrayRef tensorSizes) {
54 if (start < 0) {
55 start = tensorSizes[axis] + start;
56 }
57 if (end < 0) {
58 end = tensorSizes[axis] + end;
59 }
60 // index higher than dimension is treated as the end.
61 if (end > tensorSizes[axis]) {
62 end = tensorSizes[axis];
63 }
64 }
65
runTorchSlice_opset9(const Node * node,std::vector<at::Tensor> & inputTensorValues)66 std::optional<at::Tensor> runTorchSlice_opset9(
67 const Node* node,
68 std::vector<at::Tensor>& inputTensorValues) {
69 assert(inputTensorValues.size() == 1);
70 if (inputTensorValues.size() != 1) {
71 TORCH_WARN(
72 "Constant folding - Invalid number of inputs found for opset 9 "
73 "onnx::Slice op. Constant folding not applied.");
74 return std::nullopt;
75 }
76 if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) {
77 return std::nullopt;
78 }
79 auto startsAttr = node->is(attr::starts);
80 auto endsAttr = node->is(attr::ends);
81 if (startsAttr.size() != endsAttr.size()) {
82 return std::nullopt;
83 }
84 std::vector<int64_t> axesAttr;
85 if (node->hasAttributeS("axes")) {
86 axesAttr = node->is(attr::axes);
87 } else {
88 axesAttr.resize(startsAttr.size());
89 std::iota(axesAttr.begin(), axesAttr.end(), 0);
90 }
91 auto updated_val = inputTensorValues[0];
92 for (const auto i : c10::irange(axesAttr.size())) {
93 // ONNX slice accepts negative starts and ends values.
94 int64_t axis = axesAttr[i], start = startsAttr[i], end = endsAttr[i];
95 // ONNX slice accepts negative axis, fix this for aten op
96 axis += axis < 0 ? inputTensorValues[0].sizes().size() : 0;
97 handleNegativeStartEndIndex(start, end, axis, updated_val.sizes());
98 int64_t length = end - start;
99 if (length < 0 || start > updated_val.sizes()[axis] - length)
100 return std::nullopt;
101 updated_val = at::narrow(updated_val, axis, start, length);
102 }
103 return std::optional<at::Tensor>(updated_val);
104 }
105
runTorchSlice_opset10(const Node * node,std::vector<at::Tensor> & inputTensorValues)106 std::optional<at::Tensor> runTorchSlice_opset10(
107 const Node* node,
108 std::vector<at::Tensor>& inputTensorValues) {
109 const int maxSliceInputCount = 5;
110 const int minSliceInputCount = 3;
111 if (inputTensorValues.size() < minSliceInputCount ||
112 inputTensorValues.size() > maxSliceInputCount) {
113 TORCH_WARN(
114 "Constant folding - Invalid number of inputs found for opset opset >= 10 onnx::Slice op. "
115 "Constant folding not applied.");
116 return std::nullopt;
117 }
118 // Checking validity of 'starts' and 'ends' input
119 if (inputTensorValues[1].sizes().size() != 1 ||
120 inputTensorValues[2].sizes().size() != 1) {
121 TORCH_WARN(
122 "Constant folding - Invalid 'starts' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
123 "Constant folding not applied.");
124 return std::nullopt;
125 }
126 if (inputTensorValues[1].sizes()[0] != inputTensorValues[2].sizes()[0]) {
127 // Number of elements of 'starts' and 'ends' 1-D input tensors should be the
128 // same
129 return std::nullopt;
130 }
131 // Checking 'axes' input, if available.
132 std::vector<int64_t> axes;
133 if (inputTensorValues.size() > 3) {
134 if (inputTensorValues[3].sizes().size() != 1) {
135 TORCH_WARN(
136 "Constant folding - Invalid 'axes' input found for opset >= 10 onnx::Slice op. "
137 "Constant folding not applied.");
138 return std::nullopt;
139 }
140 if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0]) {
141 // Number of elements of 'axes' and 'ends' 1-D input tensors should be the
142 // same
143 TORCH_WARN(
144 "Constant folding - Invalid 'axes' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
145 "Constant folding not applied.");
146 return std::nullopt;
147 }
148 auto axes_a = inputTensorValues[3].accessor<int64_t, 1>();
149 axes.resize(inputTensorValues[3].sizes()[0]);
150 // ONNX slice accepts negative axis, fix this for aten op
151 for (const auto i : c10::irange(inputTensorValues[3].sizes()[0])) {
152 axes[i] = axes_a[i] < 0 ? axes_a[i] + inputTensorValues[0].sizes().size()
153 : axes_a[i];
154 }
155 } else {
156 axes = std::vector<int64_t>(inputTensorValues[1].sizes()[0], 0);
157 }
158 // Checking 'steps' input, if available.
159 if (inputTensorValues.size() > 4) {
160 if (inputTensorValues[4].sizes().size() != 1) {
161 TORCH_WARN(
162 "Constant folding - Invalid 'steps' input found for opset >= 10 onnx::Slice op. "
163 "Constant folding not applied.");
164 return std::nullopt;
165 }
166 if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0]) {
167 // Number of elements of 'steps' and 'ends' 1-D input tensors should be
168 // the same
169 TORCH_WARN(
170 "Constant folding - Invalid 'steps' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
171 "Constant folding not applied.");
172 return std::nullopt;
173 }
174 auto steps_a = inputTensorValues[4].accessor<int64_t, 1>();
175 for (const auto i : c10::irange(inputTensorValues[4].sizes()[0])) {
176 // Only steps == 1 are supported for constant-folding.
177 if (steps_a[i] != 1) {
178 TORCH_WARN(
179 "Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. "
180 "Constant folding not applied.");
181 return std::nullopt;
182 }
183 }
184 }
185 auto starts_a = inputTensorValues[1].accessor<int64_t, 1>();
186 auto ends_a = inputTensorValues[2].accessor<int64_t, 1>();
187 auto updated_val = inputTensorValues[0];
188 for (const auto i : c10::irange(inputTensorValues[1].sizes()[0])) {
189 // ONNX slice accepts negative starts and ends values.
190 int64_t start = starts_a[i], end = ends_a[i], axis = axes[i];
191 handleNegativeStartEndIndex(start, end, axis, updated_val.sizes());
192 int64_t length = end - start;
193 if (length < 0 || start > updated_val.sizes()[axis] - length)
194 return std::nullopt;
195 updated_val = at::narrow(updated_val, axis, start, length);
196 }
197 return std::optional<at::Tensor>(updated_val);
198 }
199
200 // Refer to AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF
runTorchArange_opset11(const Node * node,const std::vector<at::Tensor> & inputTensorValues)201 at::Tensor runTorchArange_opset11(
202 const Node* node,
203 const std::vector<at::Tensor>& inputTensorValues) {
204 TORCH_INTERNAL_ASSERT(inputTensorValues.size() == 3);
205 auto dtype = inputTensorValues[0].scalar_type();
206 at::Tensor updated_val;
207 switch (dtype) {
208 case at::ScalarType::Float: {
209 auto start = inputTensorValues[0].item<float>();
210 auto end = inputTensorValues[1].item<float>();
211 auto step = inputTensorValues[2].item<float>();
212 updated_val = at::arange(start, end, step);
213 break;
214 }
215 case at::ScalarType::Double: {
216 auto start = inputTensorValues[0].item<double>();
217 auto end = inputTensorValues[1].item<double>();
218 auto step = inputTensorValues[2].item<double>();
219 updated_val = at::arange(start, end, step);
220 break;
221 }
222 case at::ScalarType::Short: {
223 auto start = inputTensorValues[0].item<int16_t>();
224 auto end = inputTensorValues[1].item<int16_t>();
225 auto step = inputTensorValues[2].item<int16_t>();
226 updated_val = at::arange(start, end, step);
227 break;
228 }
229 case at::ScalarType::Int: {
230 auto start = inputTensorValues[0].item<int>();
231 auto end = inputTensorValues[1].item<int>();
232 auto step = inputTensorValues[2].item<int>();
233 updated_val = at::arange(start, end, step);
234 break;
235 }
236 case at::ScalarType::Long: {
237 auto start = inputTensorValues[0].item<int64_t>();
238 auto end = inputTensorValues[1].item<int64_t>();
239 auto step = inputTensorValues[2].item<int64_t>();
240 updated_val = at::arange(start, end, step);
241 break;
242 }
243 default: {
244 TORCH_WARN(
245 "Constant folding - ONNX Range type: ", dtype, " is not supported.");
246 }
247 }
248 return updated_val;
249 }
250
IntToTensor(int64_t value)251 at::Tensor IntToTensor(int64_t value) {
252 auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
253 std::vector<int64_t> size_data = {value};
254 auto f = at::from_blob(size_data.data(), {1}, at::kLong).to(at::kCPU);
255 // Need copy here
256 at::Tensor f_copy = at::empty({1}, options);
257 f_copy.copy_(f);
258 return at::squeeze(f_copy, 0);
259 }
260
// Evaluates a single constant ONNX node with ATen ops and returns the
// folded tensor, or std::nullopt if the op/opset combination is not
// supported for folding. `inputTensorValues` holds the already-resolved
// constant inputs of `node`, in input order. Dispatch is a flat
// if/else chain on node->kind(); unhandled kinds fall through to
// std::nullopt at the bottom.
std::optional<at::Tensor> runTorchBackendForOnnx(
    const Node* node,
    std::vector<at::Tensor>& inputTensorValues,
    int opset_version) {
  at::Tensor updated_val;
  if (node->kind() == onnx::Slice) {
    // Slice semantics differ by opset: attributes (9) vs. inputs (>=10).
    if (opset_version == ONNX_OPSET_9) {
      return runTorchSlice_opset9(node, inputTensorValues);
    } else if (opset_version >= ONNX_OPSET_10) {
      return runTorchSlice_opset10(node, inputTensorValues);
    } else {
      TORCH_WARN(
          "Constant folding - unsupported opset version. Constant folding not applied.");
      return std::nullopt;
    }
  } else if (node->kind() == onnx::Concat) {
    if (!node->hasAttributeS("axis")) {
      return std::nullopt;
    }
    updated_val =
        at::cat(at::TensorList(inputTensorValues), node->i(attr::axis));
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Sqrt) {
    updated_val = at::sqrt(inputTensorValues[0]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Div) {
    // One example shows at::div(CPULongType, CPULongType) = CPUFloatType,
    // So we add a cast below.
    updated_val = at::div(inputTensorValues[0], inputTensorValues[1]);
    // NOTE(review): the cast back only fires when both INPUT dtypes
    // match; mixed-dtype division keeps whatever at::div promoted to.
    if (inputTensorValues[0].scalar_type() ==
        inputTensorValues[1].scalar_type()) {
      updated_val = updated_val.to(inputTensorValues[0].scalar_type());
    }
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Mul) {
    updated_val = at::mul(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Sub) {
    updated_val = at::sub(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Add) {
    updated_val = at::add(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Unsqueeze) {
    if (opset_version >= ONNX_OPSET_13) {
      // Opset 13: axes come in as a second 1-D tensor input.
      assert(inputTensorValues.size() == 2);
      // Checking validity of 'axes' input
      if (inputTensorValues[1].sizes().size() != 1) {
        TORCH_WARN(
            "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Unsqueeze op. "
            "Constant folding not applied.");
        return std::nullopt;
      }
      auto axes_a = inputTensorValues[1].accessor<int64_t, 1>();
      std::vector<int64_t> axes;
      for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
        // ONNX unsqueeze accepts negative axes
        // From https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
        // Negative dim will correspond to unsqueeze() applied at dim = dim +
        // input.dim() + 1.
        // NOTE(review): this writes the normalized axis back through the
        // accessor, mutating the 'axes' input tensor in place.
        axes_a[i] +=
            axes_a[i] < 0 ? inputTensorValues[0].sizes().size() + 1 : 0;
        axes.push_back(axes_a[i]);
      }
      // Unsqueeze in ascending axis order so each inserted dim lands at
      // its intended final position.
      std::sort(axes.begin(), axes.end());
      updated_val = inputTensorValues[0];
      for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
        updated_val = at::unsqueeze(updated_val, axes[i]);
      }
      return std::optional<at::Tensor>(updated_val);
    } else if (opset_version >= ONNX_OPSET_9) {
      // Opsets 9-12: axes are a node attribute.
      assert(inputTensorValues.size() == 1);
      if (!node->hasAttributeS("axes")) {
        return std::nullopt;
      }
      updated_val = inputTensorValues[0];
      std::vector<int64_t> axesAttr = node->is(attr::axes);
      std::sort(axesAttr.begin(), axesAttr.end());
      for (auto axis : axesAttr) {
        updated_val = at::unsqueeze(updated_val, axis);
      }
      return std::optional<at::Tensor>(updated_val);
    } else {
      TORCH_WARN(
          "Constant folding - unsupported opset version. "
          "Constant folding not applied.");
      return std::nullopt;
    }
  } else if (node->kind() == onnx::Squeeze) {
    assert(inputTensorValues.size() == 2 || inputTensorValues.size() == 1);
    if (opset_version >= ONNX_OPSET_13) {
      // Squeeze version 13 input axes is optional, inputTensorValues.size() ==
      // 1 means axes equal to None
      updated_val = inputTensorValues[0];
      if (inputTensorValues.size() == 2) {
        // Checking validity of 'axes' input
        if (inputTensorValues[1].sizes().size() != 1) {
          TORCH_WARN(
              "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Squeeze op. "
              "Constant folding not applied.");
          return std::nullopt;
        }
        auto axes_a = inputTensorValues[1].accessor<int64_t, 1>();
        std::vector<int64_t> axes;
        for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
          // ONNX Squeeze accepts negative axes
          // NOTE(review): as in Unsqueeze above, this mutates the 'axes'
          // input tensor in place while normalizing.
          axes_a[i] += axes_a[i] < 0 ? inputTensorValues[0].sizes().size() : 0;
          axes.push_back(axes_a[i]);
        }
        std::sort(axes.begin(), axes.end());
        for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
          updated_val = at::squeeze(updated_val, axes[i]);
        }
      }
      return std::optional<at::Tensor>(updated_val);
    } else if (opset_version >= ONNX_OPSET_9) {
      // Opsets 9-12: axes are an (optional) node attribute.
      assert(inputTensorValues.size() == 1);
      updated_val = inputTensorValues[0];
      if (node->hasAttributeS("axes")) {
        std::vector<int64_t> axesAttr = node->is(attr::axes);
        std::sort(axesAttr.begin(), axesAttr.end());
        for (auto axis : axesAttr) {
          updated_val = at::squeeze(updated_val, axis);
        }
      }
      return std::optional<at::Tensor>(updated_val);
    } else {
      TORCH_WARN(
          "Constant folding - unsupported opset version. "
          "Constant folding not applied.");
      return std::nullopt;
    }
  } else if (node->kind() == onnx::Transpose) {
    assert(inputTensorValues.size() == 1);
    if (!node->hasAttributeS("perm")) {
      return std::nullopt;
    }
    updated_val = inputTensorValues[0].permute(node->is(attr::perm));
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Cast) {
    assert(inputTensorValues.size() == 1);
    // Only fold when the 'to' type maps onto a known ATen type.
    if (node->hasAttributeS("to") && ONNXTypeToATenType(node->i(attr::to))) {
      updated_val = inputTensorValues[0].to(
          ONNXTypeToATenType(node->i(attr::to)).value());
      return std::optional<at::Tensor>(updated_val);
    }
    return std::nullopt;
  } else if (node->kind() == onnx::Reshape) {
    assert(inputTensorValues.size() == 2);
    updated_val = inputTensorValues[0];
    std::vector<int64_t> shape(inputTensorValues[1].sizes()[0], 0);
    auto shape_a = inputTensorValues[1].accessor<int64_t, 1>();
    assert(inputTensorValues[1].sizes()[0] >= 0);
    // Set value of allowzero
    int64_t allowzero = 0;
    if (node->hasAttributeS("allowzero")) {
      allowzero = node->i(attr::allowzero);
    }
    for (size_t i = 0; i < (size_t)(inputTensorValues[1].sizes()[0]); ++i) {
      // All shape dim values should be >= -1
      // onnx::Reshape supports a shape dim value to be zero, in
      // which case the actual dim value remains unchanged. However,
      // at::reshape does not support shape dim value to be zero
      assert(shape_a[i] >= -1);
      if (shape_a[i] == 0 && !allowzero) {
        if (i >= inputTensorValues[0].sizes().size()) {
          throw std::runtime_error(
              "Dimension with value 0 exceeds the input size dimensions.");
        }
        // Replace 0 with the corresponding input dimension, as ONNX
        // "keep this dim" semantics require.
        shape[i] = inputTensorValues[0].sizes()[i];
      } else {
        shape[i] = shape_a[i];
      }
    }
    return std::optional<at::Tensor>(at::reshape(updated_val, shape));
  } else if (node->kind() == onnx::Shape) {
    TORCH_INTERNAL_ASSERT(inputTensorValues.size() == 1);
    updated_val = at::_shape_as_tensor(inputTensorValues[0]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::ReduceL1 || node->kind() == onnx::ReduceL2) {
    assert(inputTensorValues.size() == 1);
    if (!node->hasAttributeS("axes")) {
      return std::nullopt;
    }
    if (!node->hasAttributeS("keepdims")) {
      return std::nullopt;
    }
    // L1 reduction is the p=1 norm, L2 the p=2 norm.
    int p = node->kind() == onnx::ReduceL1 ? 1 : 2;
    updated_val = at::norm(
        inputTensorValues[0], p, node->is(attr::axes), node->i(attr::keepdims));
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::ReduceProd) {
    int64_t rank = inputTensorValues[0].sizes().size();
    std::vector<int64_t> axes;
    if (!node->hasAttributeS("axes")) {
      // Default: reduce over all axes, highest first.
      axes = std::vector<int64_t>(rank);
      std::iota(axes.rbegin(), axes.rend(), 0);
    } else {
      for (const auto& axis : node->is(attr::axes)) {
        axes.emplace_back(axis < 0 ? axis + rank : axis);
      }
      // Descending order, so reducing one axis (with keepdims=false)
      // does not renumber the axes still to be reduced.
      std::sort(axes.begin(), axes.end(), std::greater<>());
    }

    bool keepdims =
        node->hasAttributeS("keepdims") ? node->i(attr::keepdims) : true;
    updated_val = inputTensorValues[0];
    for (const auto& axis : axes) {
      updated_val = at::prod(updated_val, axis, keepdims);
    }
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Gather) {
    assert(inputTensorValues.size() == 2);
    // default axis = 0
    int64_t axis = 0;
    if (node->hasAttributeS("axis")) {
      axis = node->i(attr::axis);
    }
    // If axis attribute for onnx::Gather has a value less than 0,
    // It needs to be adjusted (+= dim sizes) for aten op
    axis += axis < 0 ? inputTensorValues[0].sizes().size() : 0;
    at::Tensor indices = inputTensorValues[1];
    auto q = indices.dim();
    // at::index_select only supports indices with rank <= 1.
    // See https://pytorch.org/docs/main/generated/torch.index_select.html
    if (q > 1) {
      return std::nullopt;
    }
    // If the device of indices tensor is not the same with it of the input
    // tensor, move it to the device of the input tensor
    if (inputTensorValues[0].device() != indices.device()) {
      indices = indices.to(inputTensorValues[0].device());
    }
    // If indices input for onnx::Gather has a value less than 0,
    // It needs to be adjusted (+= dim value) for aten op
    auto less_mask = at::lt(indices, 0);
    auto indices_corr = at::add(indices, inputTensorValues[0].sizes()[axis]);
    auto indices_masked = at::where(less_mask, indices_corr, indices);
    updated_val = at::index_select(inputTensorValues[0], axis, indices_masked);
    // If rank of indices is 0, rank of output tensor should be
    // rank_of_input - 1.
    if (q < 1) {
      updated_val = updated_val.squeeze(axis);
    }
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Range) {
    updated_val = runTorchArange_opset11(node, inputTensorValues);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Where) {
    updated_val = at::where(
        inputTensorValues[0], inputTensorValues[1], inputTensorValues[2]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Equal) {
    updated_val = at::eq(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Greater) {
    updated_val = at::greater(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Less) {
    updated_val = at::less(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Neg) {
    updated_val = at::neg(inputTensorValues[0]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Not) {
    // Logical NOT implemented as `x != ones_like(x)` — presumably the
    // input is 0/1-valued here; verify against the exporter's Not usage.
    auto ones =
        at::ones(inputTensorValues[0].sizes(), inputTensorValues[0].dtype());
    updated_val = at::ne(inputTensorValues[0], ones);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Size) {
    // Total element count, returned as a 0-dim Long tensor.
    int64_t total_size = 1;
    for (auto size : inputTensorValues[0].sizes()) {
      total_size *= size;
    }
    return std::optional<at::Tensor>(IntToTensor(total_size));
  } else if (node->kind() == onnx::Softmax) {
    int64_t axis = node->hasAttributeS("axis") ? node->i(attr::axis) : -1;
    updated_val = at::softmax(inputTensorValues[0], axis);
    return std::optional<at::Tensor>(updated_val);
  } else {
    // Unsupported op kind: caller skips folding this node.
    return std::nullopt;
  }
}
544
isConstant(Value * val,const ValueToParamPairMap & valsToParamsMap)545 bool isConstant(Value* val, const ValueToParamPairMap& valsToParamsMap) {
546 auto parentNode = val->node();
547 return (parentNode->kind() == prim::Param &&
548 valsToParamsMap.find(val) !=
549 valsToParamsMap
550 .end()) || // Checks val is a parameter and not a real input
551 (parentNode->kind() == onnx::Constant && !parentNode->mustBeNone() &&
552 parentNode->kindOf(attr::value) ==
553 AttributeKind::t); // Check other types?
554 }
555
hasParamInput(Node * n,const ValueToParamPairMap & valsToParamsMap)556 bool hasParamInput(Node* n, const ValueToParamPairMap& valsToParamsMap) {
557 for (auto input : n->inputs()) {
558 if (valsToParamsMap.find(input) != valsToParamsMap.end()) {
559 return true;
560 }
561 }
562 return false;
563 }
564
getValues(Node * node,const ValueToParamPairMap & valsToParamsMap)565 std::vector<at::Tensor> getValues(
566 Node* node,
567 const ValueToParamPairMap& valsToParamsMap) {
568 size_t numInputs = node->inputs().size();
569 std::vector<at::Tensor> inputTensorValues;
570 inputTensorValues.reserve(numInputs);
571 for (auto val : node->inputs()) {
572 if (val->node()->kind() == prim::Param) {
573 auto itr = valsToParamsMap.find(val);
574 if (itr == valsToParamsMap.end()) {
575 throw std::runtime_error(
576 "getValues: Input value not found amongst constant parameters.");
577 }
578 inputTensorValues.push_back(itr->second.second.toTensor());
579 } else if (val->node()->kind() == onnx::Constant) {
580 inputTensorValues.push_back(val->node()->t(attr::value));
581 } else {
582 throw std::runtime_error(
583 "getValues: Unsupported kind of constant node found.");
584 }
585 }
586 TORCH_INTERNAL_ASSERT(inputTensorValues.size() == numInputs);
587 return inputTensorValues;
588 }
589
areNodeInputsConstant(Node * node,const ValueToParamPairMap & valsToParamsMap)590 bool areNodeInputsConstant(
591 Node* node,
592 const ValueToParamPairMap& valsToParamsMap) {
593 return std::all_of(
594 node->inputs().begin(),
595 node->inputs().end(),
596 [&valsToParamsMap](Value* v) { return isConstant(v, valsToParamsMap); });
597 }
598
getOnnxConstParentsToRemove(Node * node)599 std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {
600 std::vector<Node*> parentNodes;
601 for (auto val : node->inputs()) {
602 // If the parent of 'node' is an onnx::Constant node,
603 // and 'node' is the only downstream node it serves (this
604 // is important), then push it in the list to remove.
605 if (val->node()->kind() == onnx::Constant && val->uses().size() == 1) {
606 parentNodes.push_back(val->node());
607 }
608 }
609 return parentNodes;
610 }
611
612 } // namespace onnx_constant_fold
613
// This method updates the block in-place to fold all the one-time
// constant-based computations/ops into an initializer node.
//
// NB: This is not constant folding in the traditional sense, as we
// don't try particularly hard to evaluate operations on constant nodes.
// This is more of a partial evaluation analysis, where operations on constant
// nodes can be lifted so we run them earlier, before the usual parameters are
// known.
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
  // Folding is only defined for opset 9 and up.
  if (opset_version < ONNX_OPSET_9) {
    TORCH_WARN(
        "Constant folding supported for only opsets >= 9. "
        "Constant folding not applied.");
    return;
  }
  TORCH_INTERNAL_ASSERT(b->param_node());
  auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
  // Only the root block is constant-folded. Folding nested blocks is
  // not supported for now.
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto node = *it;
    if (node->outputs().size() > 1) {
      // Constant folding for multiple-output nodes not supported. Skip it.
      continue;
    }
    if (!onnx_constant_fold::areNodeInputsConstant(node, valsToParamsMap)) {
      // If all the inputs to this node are not either parameter or
      // onnx::Constant, then skip this node.
      continue;
    }

    auto inputTensorValues =
        onnx_constant_fold::getValues(node, valsToParamsMap);
    if (inputTensorValues.empty()) {
      // This is a terminal node with no inputs, such as onnx::Constant. Skip
      // it.
      continue;
    }
    // Actually evaluate the node; nullopt means this op can't be folded.
    auto updatedValWrapped = onnx_constant_fold::runTorchBackendForOnnx(
        node, inputTensorValues, opset_version);
    if (updatedValWrapped == std::nullopt) {
      // Constant folding is not supported for this op. Skip it.
      continue;
    }

    at::Tensor updatedVal = *updatedValWrapped;
    auto newSourceNodeOutput = [&]() -> Value* {
      if (onnx_constant_fold::hasParamInput(node, valsToParamsMap)) {
        // Create a new input to the block (prim::Param node output). Add a
        // corresponding entry in valToParamMap. Replace the downstream inputs
        // with this value, and disconnect all the input values of the folded
        // node.
        auto newSourceNodeOutput = b->addInput();
        valsToParamsMap.insert(
            {newSourceNodeOutput,
             std::make_pair(newSourceNodeOutput->debugName(), updatedVal)});
        return newSourceNodeOutput;
      } else {
        // All inputs were onnx::Constant: emit a fresh onnx::Constant
        // holding the folded tensor instead of a new block input.
        auto newSourceNode =
            createONNXConstant(node->owningGraph(), node, updatedVal);
        newSourceNode->copyMetadata(node);
        return newSourceNode->output();
      }
    }();
    newSourceNodeOutput->inferTypeFrom(updatedVal);
    node->outputs().at(0)->replaceAllUsesWith(newSourceNodeOutput);
    // Next we remove the current node that has been replaced by
    // an initializer. But before we start de-wiring this node,
    // we check if any parents of this nodes were onnx::Constant
    // and remove them first, and then remove the current node.
    // If the parent was an initializer (not onnx::Constant) then
    // they are all removed by the eraseUnusedBlockInputs() call
    // (below) outside the loop.
    auto onnxConstParents =
        onnx_constant_fold::getOnnxConstParentsToRemove(node);
    node->removeAllInputs();
    for (auto* n : onnxConstParents) {
      n->destroy();
    }
    // destroyCurrent() erases the folded node while keeping the loop
    // iterator valid.
    it.destroyCurrent();
  }
  // Drop map entries and block inputs that no longer have uses, then
  // write the surviving initializers back into paramsDict.
  eraseUnusedValuesFromMap(valsToParamsMap);
  eraseUnusedBlockInputs(b);
  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
  return;
}
700
// Convenience overload: constant-folds the graph's root block and dumps
// the resulting graph for debugging.
void ConstantFoldONNX(
    std::shared_ptr<Graph>& g,
    ParamMap& paramsDict,
    int opset_version) {
  ConstantFoldONNX(g->block(), paramsDict, opset_version);
  GRAPH_DUMP("After ConstantFoldONNX:", g);
}
708
709 } // namespace torch::jit
710