1 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h> 2 3 #include <torch/csrc/jit/ir/constants.h> 4 #include <torch/csrc/jit/jit_log.h> 5 6 namespace torch::jit { 7 8 // onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0, 9 // so before converting the ints to tensors we need to cast them to floats. PrepareDivisionForONNXOnBlock(Block * block)10static void PrepareDivisionForONNXOnBlock(Block* block) { 11 for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { 12 for (auto sub : it->blocks()) { 13 PrepareDivisionForONNXOnBlock(sub); 14 } 15 WithInsertPoint guard(*it); 16 auto* subgraph = it->owningGraph(); 17 18 if (it->matches("aten::div(int a, int b) -> float")) { 19 // Cast to Float before dividing 20 std::vector<Value*> floattensor_inputs = 21 fmap(it->inputs(), [&](Value* input) { 22 auto* longtensor = 23 subgraph->insertNode(subgraph->createNumToTensor(input)) 24 ->output(); 25 longtensor->node()->copyMetadata(input->node()); 26 auto* nonblocking = subgraph->insertConstant(0); 27 auto* cast = 28 subgraph->create(aten::_cast_Float, {longtensor, nonblocking}); 29 cast->copyMetadata(*it); 30 return subgraph->insertNode(cast)->output(); 31 }); 32 33 it->replaceInput(0, floattensor_inputs[0]); 34 it->replaceInput(1, floattensor_inputs[1]); 35 it->output()->setType(TensorType::fromNumberType(*FloatType::get())); 36 } 37 } 38 } 39 PrepareDivisionForONNX(const std::shared_ptr<Graph> & graph)40void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph) { 41 PrepareDivisionForONNXOnBlock(graph->block()); 42 GRAPH_DUMP("After PrepareDivisionForONNX: ", graph); 43 } 44 45 } // namespace torch::jit 46