// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/constant_fold.h>
3 #include <torch/csrc/jit/passes/onnx/helper.h>
4 
5 #include <ATen/Functions.h>
6 
7 #include <c10/util/Exception.h>
8 #include <c10/util/irange.h>
9 #include <algorithm>
10 #include <optional>
11 
12 namespace torch::jit {
13 
14 namespace onnx {
15 using namespace ::c10::onnx;
16 }
17 
18 namespace onnx_constant_fold {
19 
// Subset of the ONNX TensorProto.DataType enumeration covering the
// numeric element types this pass knows how to fold. The numeric values
// must match the ONNX protobuf definition exactly, so each one is
// spelled out explicitly.
enum OnnxType : int {
  ONNX_FLOAT = 1,
  ONNX_UINT8 = 2,
  ONNX_INT8 = 3,
  ONNX_UINT16 = 4,
  ONNX_INT16 = 5,
  ONNX_INT32 = 6,
  ONNX_INT64 = 7,
  // 8 (STRING) and 9 (BOOL) are intentionally not represented here.
  ONNX_FLOAT16 = 10,
  ONNX_DOUBLE = 11,
  ONNX_UINT32 = 12,
};
32 
// Maps ONNX element types (TensorProto.DataType values, see OnnxType
// above) to the ATen ScalarType used when materializing folded tensors.
std::unordered_map<int, at::ScalarType> onnxTypeToScalarTypeMap = {
    // Only conversion of ONNX numeric types is included here.
    // Unsigned ONNX types are mapped to the next higher signed
    // ScalarType type.
    {ONNX_FLOAT, at::kFloat},
    {ONNX_UINT8, at::kByte},
    {ONNX_INT8, at::kChar},
    {ONNX_UINT16, at::kInt}, // no unsigned 16-bit ScalarType; widened to int32
    {ONNX_INT16, at::kShort},
    {ONNX_INT32, at::kInt},
    {ONNX_INT64, at::kLong},
    {ONNX_FLOAT16, at::kFloat}, // fp16 is widened to float for folding
    {ONNX_DOUBLE, at::kDouble},
    {ONNX_UINT32, at::kLong}, // no unsigned 32-bit ScalarType; widened to int64
};
48 
handleNegativeStartEndIndex(int64_t & start,int64_t & end,int64_t & axis,c10::IntArrayRef tensorSizes)49 void handleNegativeStartEndIndex(
50     int64_t& start,
51     int64_t& end,
52     int64_t& axis,
53     c10::IntArrayRef tensorSizes) {
54   if (start < 0) {
55     start = tensorSizes[axis] + start;
56   }
57   if (end < 0) {
58     end = tensorSizes[axis] + end;
59   }
60   // index higher than dimension is treated as the end.
61   if (end > tensorSizes[axis]) {
62     end = tensorSizes[axis];
63   }
64 }
65 
runTorchSlice_opset9(const Node * node,std::vector<at::Tensor> & inputTensorValues)66 std::optional<at::Tensor> runTorchSlice_opset9(
67     const Node* node,
68     std::vector<at::Tensor>& inputTensorValues) {
69   assert(inputTensorValues.size() == 1);
70   if (inputTensorValues.size() != 1) {
71     TORCH_WARN(
72         "Constant folding - Invalid number of inputs found for opset 9 "
73         "onnx::Slice op. Constant folding not applied.");
74     return std::nullopt;
75   }
76   if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) {
77     return std::nullopt;
78   }
79   auto startsAttr = node->is(attr::starts);
80   auto endsAttr = node->is(attr::ends);
81   if (startsAttr.size() != endsAttr.size()) {
82     return std::nullopt;
83   }
84   std::vector<int64_t> axesAttr;
85   if (node->hasAttributeS("axes")) {
86     axesAttr = node->is(attr::axes);
87   } else {
88     axesAttr.resize(startsAttr.size());
89     std::iota(axesAttr.begin(), axesAttr.end(), 0);
90   }
91   auto updated_val = inputTensorValues[0];
92   for (const auto i : c10::irange(axesAttr.size())) {
93     // ONNX slice accepts negative starts and ends values.
94     int64_t axis = axesAttr[i], start = startsAttr[i], end = endsAttr[i];
95     // ONNX slice accepts negative axis, fix this for aten op
96     axis += axis < 0 ? inputTensorValues[0].sizes().size() : 0;
97     handleNegativeStartEndIndex(start, end, axis, updated_val.sizes());
98     int64_t length = end - start;
99     if (length < 0 || start > updated_val.sizes()[axis] - length)
100       return std::nullopt;
101     updated_val = at::narrow(updated_val, axis, start, length);
102   }
103   return std::optional<at::Tensor>(updated_val);
104 }
105 
runTorchSlice_opset10(const Node * node,std::vector<at::Tensor> & inputTensorValues)106 std::optional<at::Tensor> runTorchSlice_opset10(
107     const Node* node,
108     std::vector<at::Tensor>& inputTensorValues) {
109   const int maxSliceInputCount = 5;
110   const int minSliceInputCount = 3;
111   if (inputTensorValues.size() < minSliceInputCount ||
112       inputTensorValues.size() > maxSliceInputCount) {
113     TORCH_WARN(
114         "Constant folding - Invalid number of inputs found for opset opset >= 10 onnx::Slice op. "
115         "Constant folding not applied.");
116     return std::nullopt;
117   }
118   // Checking validity of 'starts' and 'ends' input
119   if (inputTensorValues[1].sizes().size() != 1 ||
120       inputTensorValues[2].sizes().size() != 1) {
121     TORCH_WARN(
122         "Constant folding - Invalid 'starts' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
123         "Constant folding not applied.");
124     return std::nullopt;
125   }
126   if (inputTensorValues[1].sizes()[0] != inputTensorValues[2].sizes()[0]) {
127     // Number of elements of 'starts' and 'ends' 1-D input tensors should be the
128     // same
129     return std::nullopt;
130   }
131   // Checking 'axes' input, if available.
132   std::vector<int64_t> axes;
133   if (inputTensorValues.size() > 3) {
134     if (inputTensorValues[3].sizes().size() != 1) {
135       TORCH_WARN(
136           "Constant folding - Invalid 'axes' input found for opset >= 10 onnx::Slice op. "
137           "Constant folding not applied.");
138       return std::nullopt;
139     }
140     if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0]) {
141       // Number of elements of 'axes' and 'ends' 1-D input tensors should be the
142       // same
143       TORCH_WARN(
144           "Constant folding - Invalid 'axes' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
145           "Constant folding not applied.");
146       return std::nullopt;
147     }
148     auto axes_a = inputTensorValues[3].accessor<int64_t, 1>();
149     axes.resize(inputTensorValues[3].sizes()[0]);
150     // ONNX slice accepts negative axis, fix this for aten op
151     for (const auto i : c10::irange(inputTensorValues[3].sizes()[0])) {
152       axes[i] = axes_a[i] < 0 ? axes_a[i] + inputTensorValues[0].sizes().size()
153                               : axes_a[i];
154     }
155   } else {
156     axes = std::vector<int64_t>(inputTensorValues[1].sizes()[0], 0);
157   }
158   // Checking 'steps' input, if available.
159   if (inputTensorValues.size() > 4) {
160     if (inputTensorValues[4].sizes().size() != 1) {
161       TORCH_WARN(
162           "Constant folding - Invalid 'steps' input found for opset >= 10 onnx::Slice op. "
163           "Constant folding not applied.");
164       return std::nullopt;
165     }
166     if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0]) {
167       // Number of elements of 'steps' and 'ends' 1-D input tensors should be
168       // the same
169       TORCH_WARN(
170           "Constant folding - Invalid 'steps' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
171           "Constant folding not applied.");
172       return std::nullopt;
173     }
174     auto steps_a = inputTensorValues[4].accessor<int64_t, 1>();
175     for (const auto i : c10::irange(inputTensorValues[4].sizes()[0])) {
176       // Only steps == 1 are supported for constant-folding.
177       if (steps_a[i] != 1) {
178         TORCH_WARN(
179             "Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. "
180             "Constant folding not applied.");
181         return std::nullopt;
182       }
183     }
184   }
185   auto starts_a = inputTensorValues[1].accessor<int64_t, 1>();
186   auto ends_a = inputTensorValues[2].accessor<int64_t, 1>();
187   auto updated_val = inputTensorValues[0];
188   for (const auto i : c10::irange(inputTensorValues[1].sizes()[0])) {
189     // ONNX slice accepts negative starts and ends values.
190     int64_t start = starts_a[i], end = ends_a[i], axis = axes[i];
191     handleNegativeStartEndIndex(start, end, axis, updated_val.sizes());
192     int64_t length = end - start;
193     if (length < 0 || start > updated_val.sizes()[axis] - length)
194       return std::nullopt;
195     updated_val = at::narrow(updated_val, axis, start, length);
196   }
197   return std::optional<at::Tensor>(updated_val);
198 }
199 
200 // Refer to AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF
runTorchArange_opset11(const Node * node,const std::vector<at::Tensor> & inputTensorValues)201 at::Tensor runTorchArange_opset11(
202     const Node* node,
203     const std::vector<at::Tensor>& inputTensorValues) {
204   TORCH_INTERNAL_ASSERT(inputTensorValues.size() == 3);
205   auto dtype = inputTensorValues[0].scalar_type();
206   at::Tensor updated_val;
207   switch (dtype) {
208     case at::ScalarType::Float: {
209       auto start = inputTensorValues[0].item<float>();
210       auto end = inputTensorValues[1].item<float>();
211       auto step = inputTensorValues[2].item<float>();
212       updated_val = at::arange(start, end, step);
213       break;
214     }
215     case at::ScalarType::Double: {
216       auto start = inputTensorValues[0].item<double>();
217       auto end = inputTensorValues[1].item<double>();
218       auto step = inputTensorValues[2].item<double>();
219       updated_val = at::arange(start, end, step);
220       break;
221     }
222     case at::ScalarType::Short: {
223       auto start = inputTensorValues[0].item<int16_t>();
224       auto end = inputTensorValues[1].item<int16_t>();
225       auto step = inputTensorValues[2].item<int16_t>();
226       updated_val = at::arange(start, end, step);
227       break;
228     }
229     case at::ScalarType::Int: {
230       auto start = inputTensorValues[0].item<int>();
231       auto end = inputTensorValues[1].item<int>();
232       auto step = inputTensorValues[2].item<int>();
233       updated_val = at::arange(start, end, step);
234       break;
235     }
236     case at::ScalarType::Long: {
237       auto start = inputTensorValues[0].item<int64_t>();
238       auto end = inputTensorValues[1].item<int64_t>();
239       auto step = inputTensorValues[2].item<int64_t>();
240       updated_val = at::arange(start, end, step);
241       break;
242     }
243     default: {
244       TORCH_WARN(
245           "Constant folding - ONNX Range type: ", dtype, " is not supported.");
246     }
247   }
248   return updated_val;
249 }
250 
IntToTensor(int64_t value)251 at::Tensor IntToTensor(int64_t value) {
252   auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU);
253   std::vector<int64_t> size_data = {value};
254   auto f = at::from_blob(size_data.data(), {1}, at::kLong).to(at::kCPU);
255   // Need copy here
256   at::Tensor f_copy = at::empty({1}, options);
257   f_copy.copy_(f);
258   return at::squeeze(f_copy, 0);
259 }
260 
// Central constant-folding dispatcher: evaluates a single ONNX node whose
// inputs are all concrete tensors (in 'inputTensorValues', in input order)
// by calling the corresponding ATen op, and returns the resulting tensor.
// Returns std::nullopt when the op/opset/attribute combination cannot be
// folded. Several branches index inputTensorValues based on asserts rather
// than runtime checks, relying on the exporter to emit well-formed nodes.
std::optional<at::Tensor> runTorchBackendForOnnx(
    const Node* node,
    std::vector<at::Tensor>& inputTensorValues,
    int opset_version) {
  at::Tensor updated_val;
  if (node->kind() == onnx::Slice) {
    // Slice changed between opsets (attributes became inputs in opset 10),
    // so dispatch on the export opset version.
    if (opset_version == ONNX_OPSET_9) {
      return runTorchSlice_opset9(node, inputTensorValues);
    } else if (opset_version >= ONNX_OPSET_10) {
      return runTorchSlice_opset10(node, inputTensorValues);
    } else {
      TORCH_WARN(
          "Constant folding - unsupported opset version. Constant folding not applied.");
      return std::nullopt;
    }
  } else if (node->kind() == onnx::Concat) {
    if (!node->hasAttributeS("axis")) {
      return std::nullopt;
    }
    updated_val =
        at::cat(at::TensorList(inputTensorValues), node->i(attr::axis));
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Sqrt) {
    updated_val = at::sqrt(inputTensorValues[0]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Div) {
    // One example shows at::div(CPULongType, CPULongType) = CPUFloatType,
    // So we add a cast below.
    updated_val = at::div(inputTensorValues[0], inputTensorValues[1]);
    if (inputTensorValues[0].scalar_type() ==
        inputTensorValues[1].scalar_type()) {
      // Both operands share a dtype, so the folded result is cast back to
      // it to undo ATen's implicit promotion (e.g. int/int -> float).
      updated_val = updated_val.to(inputTensorValues[0].scalar_type());
    }
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Mul) {
    updated_val = at::mul(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Sub) {
    updated_val = at::sub(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Add) {
    updated_val = at::add(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Unsqueeze) {
    if (opset_version >= ONNX_OPSET_13) {
      // In opset 13 the axes come in as a second (1-D int64) input tensor.
      assert(inputTensorValues.size() == 2);
      // Checking validity of 'axes' input
      if (inputTensorValues[1].sizes().size() != 1) {
        TORCH_WARN(
            "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Unsqueeze op. "
            "Constant folding not applied.");
        return std::nullopt;
      }
      auto axes_a = inputTensorValues[1].accessor<int64_t, 1>();
      std::vector<int64_t> axes;
      for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
        // ONNX unsqueeze accepts negative axes
        // From https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
        // Negative dim will correspond to unsqueeze() applied at dim = dim +
        // input.dim() + 1.
        // NOTE(review): writing through the accessor normalizes the 'axes'
        // input tensor in place as a side effect.
        axes_a[i] +=
            axes_a[i] < 0 ? inputTensorValues[0].sizes().size() + 1 : 0;
        axes.push_back(axes_a[i]);
      }
      // Apply in ascending order so earlier unsqueezes don't shift the
      // meaning of later (larger) axis values.
      std::sort(axes.begin(), axes.end());
      updated_val = inputTensorValues[0];
      for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
        updated_val = at::unsqueeze(updated_val, axes[i]);
      }
      return std::optional<at::Tensor>(updated_val);
    } else if (opset_version >= ONNX_OPSET_9) {
      // Before opset 13 the axes are a node attribute.
      assert(inputTensorValues.size() == 1);
      if (!node->hasAttributeS("axes")) {
        return std::nullopt;
      }
      updated_val = inputTensorValues[0];
      std::vector<int64_t> axesAttr = node->is(attr::axes);
      std::sort(axesAttr.begin(), axesAttr.end());
      for (auto axis : axesAttr) {
        updated_val = at::unsqueeze(updated_val, axis);
      }
      return std::optional<at::Tensor>(updated_val);
    } else {
      TORCH_WARN(
          "Constant folding - unsupported opset version. "
          "Constant folding not applied.");
      return std::nullopt;
    }
  } else if (node->kind() == onnx::Squeeze) {
    assert(inputTensorValues.size() == 2 || inputTensorValues.size() == 1);
    if (opset_version >= ONNX_OPSET_13) {
      // Squeeze version 13 input axes is optional, inputTensorValues.size() ==
      // 1 means axes equal to None
      updated_val = inputTensorValues[0];
      if (inputTensorValues.size() == 2) {
        // Checking validity of 'axes' input
        if (inputTensorValues[1].sizes().size() != 1) {
          TORCH_WARN(
              "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Squeeze op. "
              "Constant folding not applied.");
          return std::nullopt;
        }
        auto axes_a = inputTensorValues[1].accessor<int64_t, 1>();
        std::vector<int64_t> axes;
        for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
          // ONNX Squeeze accepts negative axes
          // NOTE(review): as with Unsqueeze, this normalizes the 'axes'
          // input tensor in place via the accessor.
          axes_a[i] += axes_a[i] < 0 ? inputTensorValues[0].sizes().size() : 0;
          axes.push_back(axes_a[i]);
        }
        std::sort(axes.begin(), axes.end());
        for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
          updated_val = at::squeeze(updated_val, axes[i]);
        }
      }
      return std::optional<at::Tensor>(updated_val);
    } else if (opset_version >= ONNX_OPSET_9) {
      // Before opset 13 axes are an (optional) attribute; no attribute
      // means "squeeze all size-1 dims", which here is a no-op fold of
      // the identity since no axes are applied.
      assert(inputTensorValues.size() == 1);
      updated_val = inputTensorValues[0];
      if (node->hasAttributeS("axes")) {
        std::vector<int64_t> axesAttr = node->is(attr::axes);
        std::sort(axesAttr.begin(), axesAttr.end());
        for (auto axis : axesAttr) {
          updated_val = at::squeeze(updated_val, axis);
        }
      }
      return std::optional<at::Tensor>(updated_val);
    } else {
      TORCH_WARN(
          "Constant folding - unsupported opset version. "
          "Constant folding not applied.");
      return std::nullopt;
    }
  } else if (node->kind() == onnx::Transpose) {
    assert(inputTensorValues.size() == 1);
    if (!node->hasAttributeS("perm")) {
      return std::nullopt;
    }
    updated_val = inputTensorValues[0].permute(node->is(attr::perm));
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Cast) {
    assert(inputTensorValues.size() == 1);
    // Only fold when the ONNX 'to' type maps to a known ATen ScalarType.
    if (node->hasAttributeS("to") && ONNXTypeToATenType(node->i(attr::to))) {
      updated_val = inputTensorValues[0].to(
          ONNXTypeToATenType(node->i(attr::to)).value());
      return std::optional<at::Tensor>(updated_val);
    }
    return std::nullopt;
  } else if (node->kind() == onnx::Reshape) {
    assert(inputTensorValues.size() == 2);
    updated_val = inputTensorValues[0];
    std::vector<int64_t> shape(inputTensorValues[1].sizes()[0], 0);
    auto shape_a = inputTensorValues[1].accessor<int64_t, 1>();
    assert(inputTensorValues[1].sizes()[0] >= 0);
    // Set value of allowzero
    int64_t allowzero = 0;
    if (node->hasAttributeS("allowzero")) {
      allowzero = node->i(attr::allowzero);
    }
    for (size_t i = 0; i < (size_t)(inputTensorValues[1].sizes()[0]); ++i) {
      // All shape dim values should be >= -1
      // onnx::Reshape supports a shape dim value to be zero, in
      // which case the actual dim value remains unchanged. However,
      // at::reshape does not support shape dim value to be zero
      assert(shape_a[i] >= -1);
      if (shape_a[i] == 0 && !allowzero) {
        if (i >= inputTensorValues[0].sizes().size()) {
          throw std::runtime_error(
              "Dimension with value 0 exceeds the input size dimensions.");
        }
        // Substitute the corresponding input dimension for the 0 marker.
        shape[i] = inputTensorValues[0].sizes()[i];
      } else {
        shape[i] = shape_a[i];
      }
    }
    return std::optional<at::Tensor>(at::reshape(updated_val, shape));
  } else if (node->kind() == onnx::Shape) {
    TORCH_INTERNAL_ASSERT(inputTensorValues.size() == 1);
    updated_val = at::_shape_as_tensor(inputTensorValues[0]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::ReduceL1 || node->kind() == onnx::ReduceL2) {
    assert(inputTensorValues.size() == 1);
    if (!node->hasAttributeS("axes")) {
      return std::nullopt;
    }
    if (!node->hasAttributeS("keepdims")) {
      return std::nullopt;
    }
    // ReduceL1/ReduceL2 are the p=1 / p=2 vector norms over 'axes'.
    int p = node->kind() == onnx::ReduceL1 ? 1 : 2;
    updated_val = at::norm(
        inputTensorValues[0], p, node->is(attr::axes), node->i(attr::keepdims));
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::ReduceProd) {
    int64_t rank = inputTensorValues[0].sizes().size();
    std::vector<int64_t> axes;
    if (!node->hasAttributeS("axes")) {
      // No 'axes' attribute: reduce over all dims, highest first.
      axes = std::vector<int64_t>(rank);
      std::iota(axes.rbegin(), axes.rend(), 0);
    } else {
      // Normalize negative axes, then reduce from the highest axis down so
      // earlier reductions don't renumber the remaining axes.
      for (const auto& axis : node->is(attr::axes)) {
        axes.emplace_back(axis < 0 ? axis + rank : axis);
      }
      std::sort(axes.begin(), axes.end(), std::greater<>());
    }

    bool keepdims =
        node->hasAttributeS("keepdims") ? node->i(attr::keepdims) : true;
    updated_val = inputTensorValues[0];
    for (const auto& axis : axes) {
      updated_val = at::prod(updated_val, axis, keepdims);
    }
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Gather) {
    assert(inputTensorValues.size() == 2);
    // default axis = 0
    int64_t axis = 0;
    if (node->hasAttributeS("axis")) {
      axis = node->i(attr::axis);
    }
    // If axis attribute for onnx::Gather has a value less than 0,
    // It needs to be adjusted (+= dim sizes) for aten op
    axis += axis < 0 ? inputTensorValues[0].sizes().size() : 0;
    at::Tensor indices = inputTensorValues[1];
    auto q = indices.dim();
    // at::index_select only supports indices with rank <= 1.
    // See https://pytorch.org/docs/main/generated/torch.index_select.html
    if (q > 1) {
      return std::nullopt;
    }
    // If the device of indices tensor is not the same with it of the input
    // tensor, move it to the device of the input tensor
    if (inputTensorValues[0].device() != indices.device()) {
      indices = indices.to(inputTensorValues[0].device());
    }
    // If indices input for onnx::Gather has a value less than 0,
    // It needs to be adjusted (+= dim value) for aten op
    auto less_mask = at::lt(indices, 0);
    auto indices_corr = at::add(indices, inputTensorValues[0].sizes()[axis]);
    auto indices_masked = at::where(less_mask, indices_corr, indices);
    updated_val = at::index_select(inputTensorValues[0], axis, indices_masked);
    // If rank of indices is 0, rank of output tensor should be
    // rank_of_input - 1.
    if (q < 1) {
      updated_val = updated_val.squeeze(axis);
    }
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Range) {
    updated_val = runTorchArange_opset11(node, inputTensorValues);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Where) {
    updated_val = at::where(
        inputTensorValues[0], inputTensorValues[1], inputTensorValues[2]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Equal) {
    updated_val = at::eq(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Greater) {
    updated_val = at::greater(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Less) {
    updated_val = at::less(inputTensorValues[0], inputTensorValues[1]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Neg) {
    updated_val = at::neg(inputTensorValues[0]);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Not) {
    // Logical-not is implemented as "x != 1" against a ones tensor of the
    // same shape and dtype.
    auto ones =
        at::ones(inputTensorValues[0].sizes(), inputTensorValues[0].dtype());
    updated_val = at::ne(inputTensorValues[0], ones);
    return std::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Size) {
    // Total element count (numel) as a 0-dim kLong tensor.
    int64_t total_size = 1;
    for (auto size : inputTensorValues[0].sizes()) {
      total_size *= size;
    }
    return std::optional<at::Tensor>(IntToTensor(total_size));
  } else if (node->kind() == onnx::Softmax) {
    // ONNX Softmax defaults axis to -1 when the attribute is absent.
    int64_t axis = node->hasAttributeS("axis") ? node->i(attr::axis) : -1;
    updated_val = at::softmax(inputTensorValues[0], axis);
    return std::optional<at::Tensor>(updated_val);
  } else {
    // Unrecognized op: leave the node in the graph untouched.
    return std::nullopt;
  }
}
544 
isConstant(Value * val,const ValueToParamPairMap & valsToParamsMap)545 bool isConstant(Value* val, const ValueToParamPairMap& valsToParamsMap) {
546   auto parentNode = val->node();
547   return (parentNode->kind() == prim::Param &&
548           valsToParamsMap.find(val) !=
549               valsToParamsMap
550                   .end()) || // Checks val is a parameter and not a real input
551       (parentNode->kind() == onnx::Constant && !parentNode->mustBeNone() &&
552        parentNode->kindOf(attr::value) ==
553            AttributeKind::t); // Check other types?
554 }
555 
hasParamInput(Node * n,const ValueToParamPairMap & valsToParamsMap)556 bool hasParamInput(Node* n, const ValueToParamPairMap& valsToParamsMap) {
557   for (auto input : n->inputs()) {
558     if (valsToParamsMap.find(input) != valsToParamsMap.end()) {
559       return true;
560     }
561   }
562   return false;
563 }
564 
getValues(Node * node,const ValueToParamPairMap & valsToParamsMap)565 std::vector<at::Tensor> getValues(
566     Node* node,
567     const ValueToParamPairMap& valsToParamsMap) {
568   size_t numInputs = node->inputs().size();
569   std::vector<at::Tensor> inputTensorValues;
570   inputTensorValues.reserve(numInputs);
571   for (auto val : node->inputs()) {
572     if (val->node()->kind() == prim::Param) {
573       auto itr = valsToParamsMap.find(val);
574       if (itr == valsToParamsMap.end()) {
575         throw std::runtime_error(
576             "getValues: Input value not found amongst constant parameters.");
577       }
578       inputTensorValues.push_back(itr->second.second.toTensor());
579     } else if (val->node()->kind() == onnx::Constant) {
580       inputTensorValues.push_back(val->node()->t(attr::value));
581     } else {
582       throw std::runtime_error(
583           "getValues: Unsupported kind of constant node found.");
584     }
585   }
586   TORCH_INTERNAL_ASSERT(inputTensorValues.size() == numInputs);
587   return inputTensorValues;
588 }
589 
areNodeInputsConstant(Node * node,const ValueToParamPairMap & valsToParamsMap)590 bool areNodeInputsConstant(
591     Node* node,
592     const ValueToParamPairMap& valsToParamsMap) {
593   return std::all_of(
594       node->inputs().begin(),
595       node->inputs().end(),
596       [&valsToParamsMap](Value* v) { return isConstant(v, valsToParamsMap); });
597 }
598 
getOnnxConstParentsToRemove(Node * node)599 std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {
600   std::vector<Node*> parentNodes;
601   for (auto val : node->inputs()) {
602     // If the parent of 'node' is an onnx::Constant node,
603     // and 'node' is the only downstream node it serves (this
604     // is important), then push it in the list to remove.
605     if (val->node()->kind() == onnx::Constant && val->uses().size() == 1) {
606       parentNodes.push_back(val->node());
607     }
608   }
609   return parentNodes;
610 }
611 
612 } // namespace onnx_constant_fold
613 
614 // This method updates the block in-place to fold all the one-time
615 // constant-based computations/ops into an initializer node.
616 //
617 // NB: This is not constant folding in the traditional sense, as we
618 // don't try particularly hard to evaluate operations on constant nodes.
619 // This is more of a partial evaluation analysis, where operations on constant
620 // nodes can be lifted so we run them earlier, before the usual parameters are
621 // known.
// Folds constant subgraphs in block 'b' in place.
//   b             - root block to fold (nested blocks are not visited)
//   paramsDict    - name -> value map of the model's initializers; updated
//                   in place to reflect folded results
//   opset_version - export opset; folding requires opset >= 9
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
  if (opset_version < ONNX_OPSET_9) {
    TORCH_WARN(
        "Constant folding supported for only opsets >= 9. "
        "Constant folding not applied.");
    return;
  }
  TORCH_INTERNAL_ASSERT(b->param_node());
  auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
  // Only the root block is constant-folded. Folding nested blocks is
  // not supported for now.
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto node = *it;
    if (node->outputs().size() > 1) {
      // Constant folding for multiple-output nodes not supported. Skip it.
      continue;
    }
    if (!onnx_constant_fold::areNodeInputsConstant(node, valsToParamsMap)) {
      // If all the inputs to this node are not either parameter or
      // onnx::Constant, then skip this node.
      continue;
    }

    auto inputTensorValues =
        onnx_constant_fold::getValues(node, valsToParamsMap);
    if (inputTensorValues.empty()) {
      // This is a terminal node with no inputs, such as onnx::Constant. Skip
      // it.
      continue;
    }
    // Evaluate the node eagerly against its constant inputs.
    auto updatedValWrapped = onnx_constant_fold::runTorchBackendForOnnx(
        node, inputTensorValues, opset_version);
    if (updatedValWrapped == std::nullopt) {
      // Constant folding is not supported for this op. Skip it.
      continue;
    }

    at::Tensor updatedVal = *updatedValWrapped;
    // Produce a replacement Value for the folded result: either a new
    // block input (initializer) or a fresh onnx::Constant node.
    auto newSourceNodeOutput = [&]() -> Value* {
      if (onnx_constant_fold::hasParamInput(node, valsToParamsMap)) {
        // Create a new input to the block (prim::Param node output). Add a
        // corresponding entry in valToParamMap. Replace the downstream inputs
        // with this value, and disconnect all the input values of the folded
        // node.
        auto newSourceNodeOutput = b->addInput();
        valsToParamsMap.insert(
            {newSourceNodeOutput,
             std::make_pair(newSourceNodeOutput->debugName(), updatedVal)});
        return newSourceNodeOutput;
      } else {
        auto newSourceNode =
            createONNXConstant(node->owningGraph(), node, updatedVal);
        newSourceNode->copyMetadata(node);
        return newSourceNode->output();
      }
    }();
    newSourceNodeOutput->inferTypeFrom(updatedVal);
    node->outputs().at(0)->replaceAllUsesWith(newSourceNodeOutput);
    // Next we remove the current node that has been replaced by
    // an initializer. But before we start de-wiring this node,
    // we check if any parents of this nodes were onnx::Constant
    // and remove them first, and then remove the current node.
    // If the parent was an initializer (not onnx::Constant) then
    // they are all removed by the eraseUnusedBlockInputs() call
    // (below) outside the loop.
    auto onnxConstParents =
        onnx_constant_fold::getOnnxConstParentsToRemove(node);
    node->removeAllInputs();
    for (auto* n : onnxConstParents) {
      n->destroy();
    }
    // destroyCurrent() keeps the iterator valid for the next ++it.
    it.destroyCurrent();
  }
  // Drop map entries (and their block inputs) no longer referenced by the
  // graph, then write the surviving initializers back into paramsDict.
  eraseUnusedValuesFromMap(valsToParamsMap);
  eraseUnusedBlockInputs(b);
  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
  return;
}
700 
// Convenience overload: constant-folds the graph's root block and dumps
// the resulting graph for debugging (no-op unless graph logging is on).
void ConstantFoldONNX(
    std::shared_ptr<Graph>& g,
    ParamMap& paramsDict,
    int opset_version) {
  ConstantFoldONNX(g->block(), paramsDict, opset_version);
  GRAPH_DUMP("After ConstantFoldONNX:", g);
}
708 
709 } // namespace torch::jit
710