xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/quantization/helper.h>
5 
6 #include <stack>
7 
8 namespace torch {
9 namespace jit {
10 namespace {
11 class ModuleUseDeduper {
12  public:
ModuleUseDeduper(Module & module)13   ModuleUseDeduper(Module& module) : module_(module) {}
dedup()14   void dedup() {
15     for (auto& method : module_.get_methods()) {
16       const auto& graph = method.graph();
17       findModuleUses(graph.get());
18     }
19     dedupModuleUses();
20   }
21 
22  private:
23   // Analyze the code to record information represents
24   // uses of the module, which we'll use later to actually perform the dedup
25   // operation Please see the comments of member variables of the class for more
26   // information
findModuleUses(Graph * graph)27   void findModuleUses(Graph* graph) {
28     GRAPH_DUMP("Finding module uses for ", graph);
29 
30     std::stack<Block*> blocks_to_visit;
31     blocks_to_visit.push(graph->block());
32     Value* self = graph->inputs()[0];
33     while (!blocks_to_visit.empty()) {
34       Block* b = blocks_to_visit.top();
35       blocks_to_visit.pop();
36       for (Node* n : b->nodes()) {
37         for (Block* subblock : n->blocks()) {
38           blocks_to_visit.push(subblock);
39         }
40         if (n->kind() != prim::CallMethod) {
41           continue;
42         }
43         Value* instance = n->inputs()[0];
44         // boundary_val is the value we get when we trace back
45         // the GetAttr access chain until we hit the input of graph
46         // or a node that is not prim::GetAttr
47         auto path = getModuleAccessPath(instance, self);
48 
49         // path.size() == 0 means we're calling a method
50         // on self, we don't need to dedup uses of self
51         if (path.empty()) {
52           continue;
53         }
54         value_to_path_map_[instance] = path;
55         auto m = findChildModule(module_, path);
56         // If we fail to insert the module to the unique_modules_ set,
57         // which means there are uses of this module before this point,
58         // we'll have to rewrite the use
59         if (!unique_modules_.insert(m._ivalue()).second) {
60           uses_to_rewrite_.push_back(instance);
61           GRAPH_DEBUG("Found use to rewrite: ", instance->debugName());
62         }
63       }
64     }
65   }
66 
67   // Deduplicate module uses given the information we recorded before
dedupModuleUses()68   void dedupModuleUses() {
69     for (Value* v : uses_to_rewrite_) {
70       const auto& path = value_to_path_map_.at(v);
71       const auto& m = findChildModule(module_, path);
72       // add a clone of the child module to the parent of the duplicated module
73       const auto& child_name = addChildModule(module_, m, path);
74       TORCH_INTERNAL_ASSERT(v->node()->kind() == prim::GetAttr);
75       // change the name in GetAttr call
76       auto original_name = v->node()->s(attr::name);
77       v->node()->s_(attr::name, child_name);
78       GRAPH_UPDATE(
79           "Module use dedup: changing use of original module ",
80           original_name,
81           " to ",
82           child_name);
83     }
84   }
85 
addChildModule(Module & module,const Module & child_module,const std::vector<std::string> & path)86   std::string addChildModule(
87       Module& module,
88       const Module& child_module,
89       const std::vector<std::string>& path) {
90     TORCH_INTERNAL_ASSERT(
91         !path.empty(), "path must have at least one element.");
92     // Parent module of the leaf child module corresponding to
93     // the path
94     auto parent_of_leaf = findChildModule(
95         module, std::vector<std::string>(path.begin(), path.end() - 1));
96 
97     // Original name of the child module
98     const std::string& original_name = path[path.size() - 1];
99     int uid = 0;
100     std::string child_name = original_name + "_" + std::to_string(uid++);
101     while (parent_of_leaf.hasattr(child_name)) {
102       child_name = original_name + "_" + std::to_string(uid++);
103     }
104     parent_of_leaf.register_module(child_name, child_module.deepcopy());
105     return child_name;
106   }
107 
108   Module module_;
109   // Map from value of module instance to the list of names of submodules
110   // starting from the top level module, e.g. ["sub1", "sub2", "relu"]
111   // Also this is a cache of calling `getModuleAccessPath` of the value
112   std::unordered_map<Value*, std::vector<std::string>> value_to_path_map_;
113   // Set of unique modules that are used in the graphs
114   std::unordered_set<ModulePtr> unique_modules_;
115   // Values that represent the module instance(the use of the module)
116   // that we'll need to rewrite as a use of a cloned module
117   // instance
118   std::vector<Value*> uses_to_rewrite_;
119 };
120 
121 } // namespace
122 
DedupModuleUses(Module & module)123 void DedupModuleUses(Module& module) {
124   ModuleUseDeduper d(module);
125   d.dedup();
126 }
127 
128 } // namespace jit
129 } // namespace torch
130