xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/frozen_concat_linear.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif

#include <unordered_set>
#include <utility>
#include <vector>

namespace torch::jit {
namespace {

using Tensor = at::Tensor;

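// Merges multiple aten::linear layers that operate on the same input tensor
// into a single, wider aten::linear, then slices the fused output so the
// original uses are preserved. Roughly (sketch only; IR value names and
// constant nodes are illustrative):
//
//   %y1 = aten::linear(%x, %w1, %b1)   // %w1: [n1, k]
//   %y2 = aten::linear(%x, %w2, %b2)   // %w2: [n2, k]
//
// becomes
//
//   %w  = prim::Constant()             // cat(%w1, %w2, dim=0): [n1+n2, k]
//   %b  = prim::Constant()             // cat(%b1, %b2, dim=0)
//   %y  = aten::linear(%x, %w, %b)
//   %y1 = aten::slice(%y, dim=-1, 0, n1, 1)
//   %y2 = aten::slice(%y, dim=-1, n1, n1+n2, 1)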
class ConcatLinearLayers {
 public:
  explicit ConcatLinearLayers(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlockAndSubblocks(graph_->block());
    return graph_modified;
  }

  AliasDb* getAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

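  // Walk the block and group every aten::linear node that has constant
  // (frozen) weight and bias tensors by the Value it takes as input. Inputs
  // are also recorded in `ordered_tensor_inputs` in the order first seen.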
  void collectConstantLinearLayers(
      Block* b,
      std::unordered_map<Value*, std::vector<Node*>>& grouped_linear_layers,
      std::vector<Value*>& ordered_tensor_inputs) {
    // We are using an ordered list so that we only have to
    // check if moving items forward is a valid move, not
    // backwards. Otherwise we need to rebuild the aliasDb when we add values.

    for (Node* n : b->nodes()) {
      // Grouping together all linear layers that use the same Tensor for input
      if (n->kind() != aten::linear) {
        continue;
      }

      auto weight = n->namedInput("weight");
      auto bias = n->namedInput("bias");
      if (weight->type() == NoneType::get() ||
          bias->type() == NoneType::get()) {
        continue;
      }

      if (nonConstantParameters(n)) {
        continue;
      }
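      // Only fuse layers whose constant weight lives on CUDA.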
      auto weight_tensor = constant_as<Tensor>(weight).value();
      if (!weight_tensor.device().is_cuda()) {
        continue;
      }

      Value* linear_input = n->inputs().at(0);
      if (grouped_linear_layers.find(linear_input) ==
          grouped_linear_layers.cend()) {
        grouped_linear_layers.insert({linear_input, std::vector<Node*>()});
        ordered_tensor_inputs.push_back(linear_input);
      }
      grouped_linear_layers.find(linear_input)->second.push_back(n);
    }
  }

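  // Replace a group of compatible linear layers with a single aten::linear
  // whose weight and bias are the dim-0 concatenation of the originals, then
  // slice the fused output back apart for each original layer's uses.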
  void mergeLinearLayers(std::vector<Node*>& compatible_layers) {
    graph_modified = true;
    assert(!compatible_layers.empty());
    Node* base_node = compatible_layers[0];

    // Scope needed to make sure we free the WithInsertPoint guard
    // and reset the insert point before we delete `base_node`
    Node* linear_node = nullptr;
    {
      WithInsertPoint guard(base_node);
      auto weight_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("weight")).value();
      });
      Tensor cat_weight = at::cat(weight_list, /*dim=*/0);
      Value* cat_weight_value = graph_->insertConstant(std::move(cat_weight));

      auto bias_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("bias")).value();
      });
      Tensor cat_bias = at::cat(bias_list, /*dim=*/0);
      Value* cat_bias_value = graph_->insertConstant(std::move(cat_bias));

      auto tensor_input = base_node->inputs().at(0);
      std::vector<Value*> linear_in = {
          tensor_input, cat_weight_value, cat_bias_value};
      linear_node = graph_->create(aten::linear, linear_in);
      linear_node->insertBefore(base_node);
    }

    // Update the outputs of the nodes
    WithInsertPoint guard2(linear_node);
    Value* neg1 = graph_->insertConstant(-1);
    Value* one = graph_->insertConstant(1);

    int64_t slice_start = 0;
    Value* slice_start_val = graph_->insertConstant(0);

    for (Node* orig_node : compatible_layers) {
      // For each node in the compatible_layers list,
      // slice the output of the combined linear layer
      // and use it instead of the output of the original node.

      Tensor weight_tensor =
          constant_as<Tensor>(orig_node->namedInput("weight")).value();
      int64_t slice_end = slice_start + weight_tensor.size(0);
      Value* slice_end_val = graph_->insertConstant(slice_end);

      Node* slice = graph_->create(
          aten::slice,
          {linear_node->output(), neg1, slice_start_val, slice_end_val, one});
      slice->insertAfter(linear_node);
      orig_node->replaceAllUsesWith(slice);
      orig_node->destroy();

      slice_start = slice_end;
      slice_start_val = slice_end_val;
    }
  }

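  // Weights (and biases) are concatenated along dim 0, so every other
  // dimension must match for two tensors to be mergeable.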
  bool isNonZeroDimEqual(Tensor& tensor_a, Tensor& tensor_b) {
    if (tensor_a.dim() != tensor_b.dim()) {
      return false;
    }
    for (int64_t i = 1; i < tensor_a.dim(); i++) {
      if (tensor_a.size(i) != tensor_b.size(i)) {
        return false;
      }
    }
    return true;
  }

  // Check the linear_layer_group of a tensor to find ones that can be
  // combined
  void collectAndMergeLinearLayers(std::vector<Node*>& linear_layer_group) {
    std::unordered_set<Node*> checked_nodes;

    for (size_t i = 0; i < linear_layer_group.size(); i++) {
      Node* base_node = linear_layer_group[i];
      if (checked_nodes.count(base_node) != 0) {
        continue;
      }

      std::vector<Node*> compatible_layers;
      compatible_layers.push_back(base_node);

      auto base_weight =
          constant_as<Tensor>(base_node->namedInput("weight")).value();
      auto base_bias =
          constant_as<Tensor>(base_node->namedInput("bias")).value();

      // Now iterate over the rest of the group to see if there is
      // anything that we can coalesce `base_node` with.
      for (size_t j = i + 1; j < linear_layer_group.size(); j++) {
        auto node = linear_layer_group[j];
        if (checked_nodes.count(node) != 0) {
          continue;
        }
        auto weight = constant_as<Tensor>(node->namedInput("weight")).value();
        auto bias = constant_as<Tensor>(node->namedInput("bias")).value();

        // For now we will just keep it simple and require matching types.
        // Type promotion might cause performance to actually decrease.
        if (base_weight.dtype() != weight.dtype() ||
            base_weight.device() != weight.device() ||
            base_bias.dtype() != bias.dtype() ||
            base_bias.device() != bias.device()) {
          continue;
        }

        if (!isNonZeroDimEqual(base_weight, weight) ||
            !isNonZeroDimEqual(base_bias, bias)) {
          continue;
        }

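        // The fused linear is inserted at the first layer's position, so
        // `node` must be safely movable before every layer already collected.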
        bool can_move_before_all = true;
        for (auto n : compatible_layers) {
          can_move_before_all &=
              getAliasDb()->couldMoveBeforeTopologically(node, n);
        }
        if (!can_move_before_all) {
          continue;
        }

        // Found a node that is eligible for combination
        compatible_layers.push_back(node);
        checked_nodes.insert(node);
      }
      if (compatible_layers.size() == 1) {
        continue; // No other layers to merge
      }
      mergeLinearLayers(compatible_layers);
    }
  }

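  // Recurse into any sub-blocks first, then collect and merge the linear
  // layers of this block, grouped by shared input tensor.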
  void handleBlockAndSubblocks(Block* block) {
    for (auto node : block->nodes()) {
      for (Block* subblock : node->blocks()) {
        handleBlockAndSubblocks(subblock);
      }
    }

    // Processing for the block itself
    std::unordered_map<Value*, std::vector<Node*>> grouped_linear_layers;
    std::vector<Value*> ordered_tensor_inputs;
    collectConstantLinearLayers(
        block, grouped_linear_layers, ordered_tensor_inputs);

    // Reverse topological ordering is used to prevent the need to
    // update the aliasDb
    for (auto tensor_it = ordered_tensor_inputs.rbegin();
         tensor_it != ordered_tensor_inputs.rend();
         ++tensor_it) {
      collectAndMergeLinearLayers(grouped_linear_layers.at(*tensor_it));
    }
  }

 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified = false;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
};
} // namespace

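// Entry point for the pass: fuses frozen linear layers that share an input
// tensor. Returns true if the graph was modified.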
TORCH_API bool FrozenConcatLinear(std::shared_ptr<Graph>& graph) {
  ConcatLinearLayers concatLayers(graph);
  GRAPH_DUMP("Before FrozenConcatLinear", graph);
  bool changed = concatLayers.run();
  if (changed) {
    GRAPH_DUMP("After FrozenConcatLinear", graph);
  }
  return changed;
}

} // namespace torch::jit