// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
5 
6 namespace torch::jit {
7 
8 struct ChunkOutput {
ChunkOutputtorch::jit::ChunkOutput9   ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
10   Value* val;
11   size_t offset;
12 };
13 
getChunkOutputs(Node * chunk)14 static std::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
15   std::vector<ChunkOutput> outputs;
16   for (auto list_use : chunk->output()->uses()) {
17     if (list_use.user->matches(
18             "aten::select(t[] list, int idx) -> t", attr::idx) &&
19         list_use.user->output()->type()->cast<TensorType>()) {
20       outputs.emplace_back(
21           list_use.user->output(),
22           list_use.user->get<int64_t>(attr::idx).value());
23     } else if (list_use.user->kind() == prim::ListUnpack) {
24       // This sometimes happens if the sizes can't be evenly divided by the
25       // number of chunks
26       if (static_cast<int64_t>(list_use.user->outputs().size()) !=
27           chunk->get<int64_t>(attr::chunks).value()) {
28         return std::nullopt;
29       }
30       auto unpack_outputs = list_use.user->outputs();
31       for (const auto i : c10::irange(unpack_outputs.size())) {
32         outputs.emplace_back(unpack_outputs[i], i);
33       }
34     } else {
35       return std::nullopt;
36     }
37   }
38   return outputs;
39 }
40 
CanonicalizeOps(Block * block)41 static void CanonicalizeOps(Block* block) {
42   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
43        ++it) {
44     for (auto sub : it->blocks())
45       CanonicalizeOps(sub);
46     if (it->matches(
47             "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
48         it->matches(
49             "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
50         it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
51         it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
52       // Replace rank 0 Tensor constants with scalar constants.
53       if (auto other = it->get<at::Tensor>(attr::other)) {
54         if (other->dim() == 0) {
55           WithInsertPoint insert_guard{*it};
56           auto graph = it->owningGraph();
57           auto new_other = graph->insertConstant(other->item());
58           std::vector<Value*> inputs = it->inputs().vec();
59           inputs.at(1) = new_other;
60           Value* new_output =
61               graph->insertNode(graph->create(it->kind(), inputs))->output();
62           new_output->node()->copyMetadata(*it);
63           new_output->copyMetadata(it->output());
64           it->output()->replaceAllUsesWith(new_output);
65         }
66       }
67     } else if (it->matches(
68                    "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
69                    /*const_inputs=*/{attr::chunks, attr::dim})) {
70       // Replace aten::chunk (which returns a list) with ConstantChunk with the
71       // outputs unpacked.
72       if (auto orig_outputs = getChunkOutputs(*it)) {
73         WithInsertPoint guard(*it);
74         auto* self = it->namedInput(attr::self);
75         auto* graph = it->owningGraph();
76         const auto chunks = it->get<int64_t>(attr::chunks).value();
77         const auto dim = it->get<int64_t>(attr::dim).value();
78         auto* node =
79             graph->insertNode(graph->create(prim::ConstantChunk, chunks));
80         node->addInput(self);
81         node->i_(attr::chunks, chunks)->i_(attr::dim, dim);
82         node->copyMetadata(*it);
83         for (const auto& orig_out : *orig_outputs) {
84           orig_out.val->replaceAllUsesWith(node->outputs()[orig_out.offset]);
85           node->outputs()[orig_out.offset]->setType(orig_out.val->type());
86         }
87       }
88     }
89   }
90 }
91 
CanonicalizeOps(const std::shared_ptr<Graph> & graph)92 void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
93   CanonicalizeOps(graph->block());
94   GRAPH_DUMP("After CanonicalizeOps: ", graph);
95   EliminateDeadCode(graph);
96 }
97 
98 } // namespace torch::jit
99