#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch::jit {

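// One tensor extracted from the list produced by an aten::chunk node,
// together with its position (offset) in that list.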
struct ChunkOutput {
  ChunkOutput(Value* v, size_t o) : val(v), offset(o) {}
  Value* val;
  size_t offset;
};

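// Collects every tensor pulled out of the list returned by `chunk`, paired
// with its index in that list. Only two consumption patterns are recognized:
// aten::select with a constant index, and a prim::ListUnpack over the whole
// list. Any other use means the chunk cannot be safely rewritten, so
// std::nullopt is returned.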
static std::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
  std::vector<ChunkOutput> outputs;
  for (auto list_use : chunk->output()->uses()) {
    if (list_use.user->matches(
            "aten::select(t[] list, int idx) -> t", attr::idx) &&
        list_use.user->output()->type()->cast<TensorType>()) {
      outputs.emplace_back(
          list_use.user->output(),
          list_use.user->get<int64_t>(attr::idx).value());
    } else if (list_use.user->kind() == prim::ListUnpack) {
      // The unpack arity can differ from the requested chunk count when the
      // dimension size is not evenly divisible by the number of chunks; we
      // cannot canonicalize in that case.
      if (static_cast<int64_t>(list_use.user->outputs().size()) !=
          chunk->get<int64_t>(attr::chunks).value()) {
        return std::nullopt;
      }
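      // Arity matches the requested chunk count, so ListUnpack output i is
      // exactly chunk i.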
      auto unpack_outputs = list_use.user->outputs();
      for (const auto i : c10::irange(unpack_outputs.size())) {
        outputs.emplace_back(unpack_outputs[i], i);
      }
    } else {
      return std::nullopt;
    }
  }
  return outputs;
}

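// Canonicalizes the nodes in `block`, recursing into sub-blocks:
//  1. binary ops whose `other` operand is a constant rank-0 tensor are
//     rewritten to take the equivalent scalar constant instead;
//  2. aten::chunk nodes with constant `chunks`/`dim` whose list output is
//     fully unpacked are replaced by prim::ConstantChunk, which yields each
//     chunk as a separate output.
// Old nodes are left in place and cleaned up later by dead code elimination.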
static void CanonicalizeOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto sub : it->blocks())
      CanonicalizeOps(sub);
    if (it->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        it->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
        it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
      // Replace rank-0 Tensor constants with scalar constants.
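      // e.g. aten::mul(%x, %c), where %c is a 0-dim tensor constant holding
      // 2.0, becomes aten::mul(%x, 2.0), so later passes see a true Scalar.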
      if (auto other = it->get<at::Tensor>(attr::other)) {
        if (other->dim() == 0) {
          WithInsertPoint insert_guard{*it};
          auto graph = it->owningGraph();
          auto new_other = graph->insertConstant(other->item());
          std::vector<Value*> inputs = it->inputs().vec();
          inputs.at(1) = new_other;
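          // Build a replacement node of the same kind with the scalar
          // operand, carry over debug metadata and the output type, and
          // rewire all uses of the old output.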
          Value* new_output =
              graph->insertNode(graph->create(it->kind(), inputs))->output();
          new_output->node()->copyMetadata(*it);
          new_output->copyMetadata(it->output());
          it->output()->replaceAllUsesWith(new_output);
        }
      }
    } else if (it->matches(
                   "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
                   /*const_inputs=*/{attr::chunks, attr::dim})) {
      // Replace aten::chunk (which returns a list) with prim::ConstantChunk,
      // whose outputs are already unpacked.
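      // Roughly: %lst = aten::chunk(%x, 2, 0) followed by two selects becomes
      // %a, %b = prim::ConstantChunk[chunks=2, dim=0](%x).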
      if (auto orig_outputs = getChunkOutputs(*it)) {
        WithInsertPoint guard(*it);
        auto* self = it->namedInput(attr::self);
        auto* graph = it->owningGraph();
        const auto chunks = it->get<int64_t>(attr::chunks).value();
        const auto dim = it->get<int64_t>(attr::dim).value();
        auto* node =
            graph->insertNode(graph->create(prim::ConstantChunk, chunks));
        node->addInput(self);
        node->i_(attr::chunks, chunks)->i_(attr::dim, dim);
        node->copyMetadata(*it);
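        // Redirect each former list element to the matching ConstantChunk
        // output, preserving any refined tensor type the element carried.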
        for (const auto& orig_out : *orig_outputs) {
          orig_out.val->replaceAllUsesWith(node->outputs()[orig_out.offset]);
          node->outputs()[orig_out.offset]->setType(orig_out.val->type());
        }
      }
    }
  }
}

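// Public entry point: canonicalize the whole graph, then run dead code
// elimination to drop the now-unused original nodes.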
void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
  CanonicalizeOps(graph->block());
  GRAPH_DUMP("After CanonicalizeOps: ", graph);
  EliminateDeadCode(graph);
}

} // namespace torch::jit