xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/constant_pooling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/constant_pooling.h>
2 
3 #include <ATen/core/symbol.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/node_hashing.h>
7 #include <unordered_set>
8 
9 namespace torch::jit {
10 
11 namespace {
12 
13 // Very similar to the common subexpression elimination pass
14 // Move all constants to the beginning of the graph, and deduplicate
ConstantPooling(Block * block,std::unordered_set<Node *,HashNode,EqualNode> & constants,const AliasDb & aliasDb)15 void ConstantPooling(
16     Block* block,
17     std::unordered_set<Node*, HashNode, EqualNode>& constants,
18     const AliasDb& aliasDb) {
19   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
20     auto node = *it;
21     // node may be moved to a different block so advance iterator now
22     ++it;
23     if (!node->blocks().empty()) {
24       // Traverse sub-blocks.
25       for (auto block : node->blocks()) {
26         ConstantPooling(block, constants, aliasDb);
27       }
28       continue;
29     }
30 
31     if (node->kind() != prim::Constant) {
32       continue;
33     }
34 
35     // Check whether the same constant already exists.
36     auto subit = constants.insert(node);
37     if (!subit.second) {
38       auto existing = *subit.first;
39 
40       auto old_ivalue = toIValue(existing->output());
41       auto new_ivalue = toIValue(node->output());
42 
43       // if both values are the same object, we do not need to worry about
44       // changing the aliasing relationship
45       bool same_identity =
46           (old_ivalue && new_ivalue && (old_ivalue->is(new_ivalue)));
47 
48       if (!same_identity &&
49           !aliasDb.safeToChangeAliasingRelationship(
50               node->outputs(), existing->outputs())) {
51         continue;
52       }
53 
54       // constant exists, replace the uses of node, and destroy it.
55       node->replaceAllUsesWith(existing);
56       node->destroy();
57       continue;
58     }
59 
60     // Move the constant definition to the beginning of the graph.
61     auto first_node = node->owningGraph()->block()->nodes().front();
62     if (node != first_node)
63       node->moveBefore(first_node);
64   }
65 }
66 } // anonymous namespace
67 
ConstantPooling(const std::shared_ptr<Graph> & graph)68 void ConstantPooling(const std::shared_ptr<Graph>& graph) {
69   AliasDb aliasDb(graph);
70   std::unordered_set<Node*, HashNode, EqualNode> constants;
71   ConstantPooling(graph->block(), constants, aliasDb);
72 }
73 } // namespace torch::jit
74