// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
2 #include <torch/csrc/jit/passes/onnx/helper.h>
3 
4 namespace torch::jit {
5 namespace onnx {
6 using namespace ::c10::onnx;
7 }
8 
// For ONNX opset < 9, the constant operator supports only three data types:
// float16, float, and double. Constants of other data types are exported as
// float or double and then cast back to their original data type with a cast
// node. The above transformation is done in this pass. The motivation behind
// having it as a post-process pass, as opposed to handling it in symbolic, is
// that many constant operators would have already been removed in the export
// before this step. On the other hand, if the cast is inserted in symbolic,
// subsequent node conversion will break if it depends on certain inputs being
// constant.
CastAllConstantToFloating(Block * block)17 void CastAllConstantToFloating(Block* block) {
18   auto graph = block->owningGraph();
19   auto it = block->nodes().begin();
20   while (it != block->nodes().end()) {
21     auto node = *it;
22     ++it;
23     for (auto block : node->blocks()) {
24       CastAllConstantToFloating(block);
25     }
26 
27     if (node->kind() == onnx::Constant) {
28       auto val = node->t(attr::value);
29       at::ScalarType dtype = val.scalar_type();
30       auto val_type = TensorType::create(val);
31       if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float &&
32           dtype != at::ScalarType::Half) {
33         int to_type = 0;
34         switch (val.scalar_type()) {
35           case at::ScalarType::Byte:
36           case at::ScalarType::Char:
37           case at::ScalarType::Int:
38           case at::ScalarType::Short:
39           case at::ScalarType::Bool:
40             to_type = ATenTypeToOnnxType(val.scalar_type());
41             val = val.to(at::ScalarType::Float);
42             break;
43 
44           case at::ScalarType::Long:
45             to_type = ATenTypeToOnnxType(val.scalar_type());
46             val = val.to(at::ScalarType::Double);
47             break;
48 
49           default:
50             throw std::runtime_error("Unsupported types: complex, string");
51         }
52         // create a cast node
53         node->removeAttribute(attr::value);
54         node->t_(attr::value, val);
55         Node* cast_node = graph->create(onnx::Cast, 1);
56         cast_node->i_(attr::to, to_type);
57         cast_node->output()->setType(val_type);
58         cast_node->insertAfter(node);
59         // get input from cast node
60         node->outputs().at(0)->replaceAllUsesWith(cast_node->outputs().at(0));
61         // add input from constant to cast node
62         cast_node->addInput(node->outputs().at(0));
63         cast_node->copyMetadata(node);
64       }
65     }
66   }
67 }
68 
CastAllConstantToFloating(const std::shared_ptr<Graph> & graph)69 void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph) {
70   CastAllConstantToFloating(graph->block());
71 }
72 } // namespace torch::jit
73