1 #include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h> 2 #include <torch/csrc/jit/passes/onnx/helper.h> 3 4 namespace torch::jit { 5 namespace onnx { 6 using namespace ::c10::onnx; 7 } 8 9 // For ONNX opset < 9, constant operator supports only three data types: 10 // float16, float, and double. Constants of other data types are exported as 11 // float or double and then cast back to their original data type with a cast 12 // node. The above transformation is done in this pass. The motivation behind 13 // having it as a post process pass opposed to handling in symbolic, is that 14 // many constant operators would have already been removed in the export before 15 // this step. On the other hand if cast is inserted in symbolic, subsequent node 16 // conversion will break if it depends on certain inputs being constant. CastAllConstantToFloating(Block * block)17void CastAllConstantToFloating(Block* block) { 18 auto graph = block->owningGraph(); 19 auto it = block->nodes().begin(); 20 while (it != block->nodes().end()) { 21 auto node = *it; 22 ++it; 23 for (auto block : node->blocks()) { 24 CastAllConstantToFloating(block); 25 } 26 27 if (node->kind() == onnx::Constant) { 28 auto val = node->t(attr::value); 29 at::ScalarType dtype = val.scalar_type(); 30 auto val_type = TensorType::create(val); 31 if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float && 32 dtype != at::ScalarType::Half) { 33 int to_type = 0; 34 switch (val.scalar_type()) { 35 case at::ScalarType::Byte: 36 case at::ScalarType::Char: 37 case at::ScalarType::Int: 38 case at::ScalarType::Short: 39 case at::ScalarType::Bool: 40 to_type = ATenTypeToOnnxType(val.scalar_type()); 41 val = val.to(at::ScalarType::Float); 42 break; 43 44 case at::ScalarType::Long: 45 to_type = ATenTypeToOnnxType(val.scalar_type()); 46 val = val.to(at::ScalarType::Double); 47 break; 48 49 default: 50 throw std::runtime_error("Unsupported types: complex, string"); 51 } 52 // create a cast node 53 node->removeAttribute(attr::value); 54 node->t_(attr::value, val); 55 Node* cast_node = graph->create(onnx::Cast, 1); 56 cast_node->i_(attr::to, to_type); 57 cast_node->output()->setType(val_type); 58 cast_node->insertAfter(node); 59 // get input from cast node 60 node->outputs().at(0)->replaceAllUsesWith(cast_node->outputs().at(0)); 61 // add input from constant to cast node 62 cast_node->addInput(node->outputs().at(0)); 63 cast_node->copyMetadata(node); 64 } 65 } 66 } 67 } 68 CastAllConstantToFloating(const std::shared_ptr<Graph> & graph)69void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph) { 70 CastAllConstantToFloating(graph->block()); 71 } 72 } // namespace torch::jit 73