xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/erase_number_types.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/erase_number_types.h>
2 
3 #include <torch/csrc/jit/ir/constants.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 
7 #include <ATen/ScalarOps.h>
8 
9 namespace torch::jit {
10 
SetNumTypeToTensorType(Value * v)11 static void SetNumTypeToTensorType(Value* v) {
12   if (v->type()->isSubtypeOf(*NumberType::get())) {
13     v->setType(TensorType::fromNumberType(*v->type()));
14   } else if (v->type()->isSubtypeOf(*BoolType::get())) {
15     v->setType(TensorType::fromBoolType());
16   }
17 }
18 
EraseNumberTypesOnBlock(Block * block)19 void EraseNumberTypesOnBlock(Block* block) {
20   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
21        ++it) {
22     for (auto inp : it->inputs()) {
23       SetNumTypeToTensorType(inp);
24     }
25     for (auto sub : it->blocks()) {
26       EraseNumberTypesOnBlock(sub);
27     }
28     switch (it->kind()) {
29       case prim::Constant: {
30         // remove primitive constants, replacing with tensor equivalent
31         // ONNX does not support non-tensor constants
32         if (it->output()->type()->isSubtypeOf(*NumberType::get()) ||
33             it->output()->type()->isSubtypeOf(*BoolType::get())) {
34           at::Scalar s;
35           if (it->output()->type()->isSubtypeOf(*BoolType::get())) {
36             s = *constant_as<bool>(it->output());
37           } else {
38             s = *constant_as<at::Scalar>(it->output());
39           }
40 
41           WithInsertPoint guard(*it);
42           Value* r = block->owningGraph()->insertConstant(
43               scalar_to_tensor(s), std::nullopt, it->scope());
44           r->copyMetadata(it->output());
45           it->output()->replaceAllUsesWith(r);
46           it.destroyCurrent();
47         }
48       } break;
49       case aten::Bool:
50       case aten::Float:
51       case aten::Int:
52       case aten::FloatImplicit:
53       case aten::IntImplicit:
54       case aten::ScalarImplicit:
55       case prim::NumToTensor: {
56         it->output()->replaceAllUsesWith(it->inputs()[0]);
57         it.destroyCurrent();
58       } break;
59       default: {
60         for (auto o : it->outputs()) {
61           SetNumTypeToTensorType(o);
62         }
63       } break;
64     }
65   }
66 }
67 
EraseNumberTypes(const std::shared_ptr<Graph> & graph)68 void EraseNumberTypes(const std::shared_ptr<Graph>& graph) {
69   for (auto inp : graph->inputs()) {
70     SetNumTypeToTensorType(inp);
71   }
72   EraseNumberTypesOnBlock(graph->block());
73   GRAPH_DUMP("After EraseNumberTypes: ", graph);
74 }
75 } // namespace torch::jit
76