1 #include <torch/csrc/jit/passes/constant_pooling.h>
2
3 #include <ATen/core/symbol.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/node_hashing.h>
7 #include <unordered_set>
8
9 namespace torch::jit {
10
11 namespace {
12
13 // Very similar to the common subexpression elimination pass
14 // Move all constants to the beginning of the graph, and deduplicate
ConstantPooling(Block * block,std::unordered_set<Node *,HashNode,EqualNode> & constants,const AliasDb & aliasDb)15 void ConstantPooling(
16 Block* block,
17 std::unordered_set<Node*, HashNode, EqualNode>& constants,
18 const AliasDb& aliasDb) {
19 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
20 auto node = *it;
21 // node may be moved to a different block so advance iterator now
22 ++it;
23 if (!node->blocks().empty()) {
24 // Traverse sub-blocks.
25 for (auto block : node->blocks()) {
26 ConstantPooling(block, constants, aliasDb);
27 }
28 continue;
29 }
30
31 if (node->kind() != prim::Constant) {
32 continue;
33 }
34
35 // Check whether the same constant already exists.
36 auto subit = constants.insert(node);
37 if (!subit.second) {
38 auto existing = *subit.first;
39
40 auto old_ivalue = toIValue(existing->output());
41 auto new_ivalue = toIValue(node->output());
42
43 // if both values are the same object, we do not need to worry about
44 // changing the aliasing relationship
45 bool same_identity =
46 (old_ivalue && new_ivalue && (old_ivalue->is(new_ivalue)));
47
48 if (!same_identity &&
49 !aliasDb.safeToChangeAliasingRelationship(
50 node->outputs(), existing->outputs())) {
51 continue;
52 }
53
54 // constant exists, replace the uses of node, and destroy it.
55 node->replaceAllUsesWith(existing);
56 node->destroy();
57 continue;
58 }
59
60 // Move the constant definition to the beginning of the graph.
61 auto first_node = node->owningGraph()->block()->nodes().front();
62 if (node != first_node)
63 node->moveBefore(first_node);
64 }
65 }
66 } // anonymous namespace
67
ConstantPooling(const std::shared_ptr<Graph> & graph)68 void ConstantPooling(const std::shared_ptr<Graph>& graph) {
69 AliasDb aliasDb(graph);
70 std::unordered_set<Node*, HashNode, EqualNode> constants;
71 ConstantPooling(graph->block(), constants, aliasDb);
72 }
73 } // namespace torch::jit
74