xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
2 
3 #include <torch/csrc/jit/ir/constants.h>
4 #include <torch/csrc/jit/jit_log.h>
5 
6 namespace torch::jit {
7 
8 // onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0,
9 // so before converting the ints to tensors we need to cast them to floats.
PrepareDivisionForONNXOnBlock(Block * block)10 static void PrepareDivisionForONNXOnBlock(Block* block) {
11   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
12     for (auto sub : it->blocks()) {
13       PrepareDivisionForONNXOnBlock(sub);
14     }
15     WithInsertPoint guard(*it);
16     auto* subgraph = it->owningGraph();
17 
18     if (it->matches("aten::div(int a, int b) -> float")) {
19       // Cast to Float before dividing
20       std::vector<Value*> floattensor_inputs =
21           fmap(it->inputs(), [&](Value* input) {
22             auto* longtensor =
23                 subgraph->insertNode(subgraph->createNumToTensor(input))
24                     ->output();
25             longtensor->node()->copyMetadata(input->node());
26             auto* nonblocking = subgraph->insertConstant(0);
27             auto* cast =
28                 subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
29             cast->copyMetadata(*it);
30             return subgraph->insertNode(cast)->output();
31           });
32 
33       it->replaceInput(0, floattensor_inputs[0]);
34       it->replaceInput(1, floattensor_inputs[1]);
35       it->output()->setType(TensorType::fromNumberType(*FloatType::get()));
36     }
37   }
38 }
39 
PrepareDivisionForONNX(const std::shared_ptr<Graph> & graph)40 void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph) {
41   PrepareDivisionForONNXOnBlock(graph->block());
42   GRAPH_DUMP("After PrepareDivisionForONNX: ", graph);
43 }
44 
45 } // namespace torch::jit
46