#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif

#include <unordered_set>
#include <utility>
#include <vector>

namespace torch::jit {
namespace {

using Tensor = at::Tensor;

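// ConcatLinearLayers merges frozen linear layers that share the same input
// tensor into a single, wider aten::linear call. The constant weights and
// biases are concatenated along dim 0, and each original output is recovered
// by slicing the combined output along the last dimension.
//
// Schematically (a rough sketch, not exact IR), a frozen graph such as
//   %y1 = aten::linear(%x, %w1, %b1)
//   %y2 = aten::linear(%x, %w2, %b2)
// becomes
//   %y  = aten::linear(%x, cat(%w1, %w2, dim=0), cat(%b1, %b2, dim=0))
//   %y1 = aten::slice(%y, dim=-1, start=0,    end=out1,      step=1)
//   %y2 = aten::slice(%y, dim=-1, start=out1, end=out1+out2, step=1)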
class ConcatLinearLayers {
 public:
  explicit ConcatLinearLayers(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlockAndSubblocks(graph_->block());
    return graph_modified;
  }

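  // Lazily construct the AliasDb the first time it is needed; the same
  // instance is reused for all topological-move queries in this pass.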
  AliasDb* getAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

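  // Gather all frozen aten::linear nodes in `b` whose weight and bias are
  // constant CUDA tensors, grouped by the input Value they consume.
  // `ordered_tensor_inputs` records those input Values in the order the
  // linear nodes are encountered.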
  void collectConstantLinearLayers(
      Block* b,
      std::unordered_map<Value*, std::vector<Node*>>& grouped_linear_layers,
      std::vector<Value*>& ordered_tensor_inputs) {
    // We are using an ordered list so that we only have to
    // check if moving items forward is a valid move, not
    // backwards. Otherwise we need to rebuild the aliasDb when we add values.

    for (Node* n : b->nodes()) {
      // Grouping together all linear layers that use the same Tensor for input
      if (n->kind() != aten::linear) {
        continue;
      }

      auto weight = n->namedInput("weight");
      auto bias = n->namedInput("bias");
      if (weight->type() == NoneType::get() ||
          bias->type() == NoneType::get()) {
        continue;
      }

      if (nonConstantParameters(n)) {
        continue;
      }
      auto weight_tensor = constant_as<Tensor>(weight).value();
      if (!weight_tensor.device().is_cuda()) {
        continue;
      }

      Value* linear_input = n->inputs().at(0);
      if (grouped_linear_layers.find(linear_input) ==
          grouped_linear_layers.cend()) {
        grouped_linear_layers.insert({linear_input, std::vector<Node*>()});
        ordered_tensor_inputs.push_back(linear_input);
      }
      grouped_linear_layers.find(linear_input)->second.push_back(n);
    }
  }

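  // Replace every node in `compatible_layers` with a single aten::linear on
  // their shared input: the weights and biases are concatenated along dim 0,
  // and each original output is carved back out with an aten::slice over the
  // last dimension.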
  void mergeLinearLayers(std::vector<Node*>& compatible_layers) {
    graph_modified = true;
    assert(!compatible_layers.empty());
    Node* base_node = compatible_layers[0];

    // Scope needed to make sure we free the WithInsertPoint guard
    // and reset the insert point before we delete `base_node`
    Node* linear_node = nullptr;
    {
      WithInsertPoint guard(base_node);
      auto weight_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("weight")).value();
      });
      Tensor cat_weight = at::cat(weight_list, /*dim=*/0);
      Value* cat_weight_value = graph_->insertConstant(std::move(cat_weight));

      auto bias_list = c10::fmap(compatible_layers, [](Node* n) {
        return constant_as<Tensor>(n->namedInput("bias")).value();
      });
      Tensor cat_bias = at::cat(bias_list, /*dim=*/0);
      Value* cat_bias_value = graph_->insertConstant(std::move(cat_bias));

      auto tensor_input = base_node->inputs().at(0);
      std::vector<Value*> linear_in = {
          tensor_input, cat_weight_value, cat_bias_value};
      linear_node = graph_->create(aten::linear, linear_in);
      linear_node->insertBefore(base_node);
    }

    // Update the outputs of the nodes
    WithInsertPoint guard2(linear_node);
    Value* neg1 = graph_->insertConstant(-1);
    Value* one = graph_->insertConstant(1);

    int64_t slice_start = 0;
    Value* slice_start_val = graph_->insertConstant(0);

    for (Node* orig_node : compatible_layers) {
      // For each node in the compatible_layers list,
      // slice the output of the combined linear layer
      // and use it instead of the output of the original node

      Tensor weight_tensor =
          constant_as<Tensor>(orig_node->namedInput("weight")).value();
      int64_t slice_end = slice_start + weight_tensor.size(0);
      Value* slice_end_val = graph_->insertConstant(slice_end);

      Node* slice = graph_->create(
          aten::slice,
          {linear_node->output(), neg1, slice_start_val, slice_end_val, one});
      slice->insertAfter(linear_node);
      orig_node->replaceAllUsesWith(slice);
      orig_node->destroy();

      slice_start = slice_end;
      slice_start_val = slice_end_val;
    }
  }

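  // True if the two tensors have the same number of dimensions and the same
  // size in every dimension except dim 0 (the dimension we concatenate over).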
  bool isNonZeroDimEqual(Tensor& tensor_a, Tensor& tensor_b) {
    if (tensor_a.dim() != tensor_b.dim()) {
      return false;
    }
    for (int64_t i = 1; i < tensor_a.dim(); i++) {
      if (tensor_a.size(i) != tensor_b.size(i)) {
        return false;
      }
    }
    return true;
  }

  // Check the linear_layer_group of a tensor to find ones that can be
  // combined
  void collectAndMergeLinearLayers(std::vector<Node*>& linear_layer_group) {
    std::unordered_set<Node*> checked_nodes;

    for (size_t i = 0; i < linear_layer_group.size(); i++) {
      Node* base_node = linear_layer_group[i];
      if (checked_nodes.count(base_node) != 0) {
        continue;
      }

      std::vector<Node*> compatible_layers;
      compatible_layers.push_back(base_node);

      auto base_weight =
          constant_as<Tensor>(base_node->namedInput("weight")).value();
      auto base_bias =
          constant_as<Tensor>(base_node->namedInput("bias")).value();

      // Now iterate over the rest of the group to see if there is anything
      // that we can coalesce `base_node` with.
      for (size_t j = i + 1; j < linear_layer_group.size(); j++) {
        auto node = linear_layer_group[j];
        if (checked_nodes.count(node) != 0) {
          continue;
        }
        auto weight = constant_as<Tensor>(node->namedInput("weight")).value();
        auto bias = constant_as<Tensor>(node->namedInput("bias")).value();

        // For now we will just keep it simple and require matching types
        // Type promotion might cause performance to actually decrease.
        if (base_weight.dtype() != weight.dtype() ||
            base_weight.device() != weight.device() ||
            base_bias.dtype() != bias.dtype() ||
            base_bias.device() != bias.device()) {
          continue;
        }

        if (!isNonZeroDimEqual(base_weight, weight) ||
            !isNonZeroDimEqual(base_bias, bias)) {
          continue;
        }

        bool can_move_before_all = true;
        for (auto n : compatible_layers) {
          can_move_before_all &=
              getAliasDb()->couldMoveBeforeTopologically(node, n);
        }
        if (!can_move_before_all) {
          continue;
        }

        // Found a node that is eligible for combination
        compatible_layers.push_back(node);
        checked_nodes.insert(node);
      }
      if (compatible_layers.size() == 1) {
        continue; // No other layers to merge
      }
      mergeLinearLayers(compatible_layers);
    }
  }

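  // Recurse into any sub-blocks first, then group and merge the frozen
  // linear layers of this block itself.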
  void handleBlockAndSubblocks(Block* block) {
    for (auto node : block->nodes()) {
      for (Block* subblock : node->blocks()) {
        handleBlockAndSubblocks(subblock);
      }
    }

    // Processing for the block itself
    std::unordered_map<Value*, std::vector<Node*>> grouped_linear_layers;
    std::vector<Value*> ordered_tensor_inputs;
    collectConstantLinearLayers(
        block, grouped_linear_layers, ordered_tensor_inputs);

    // Reverse topological ordering is used to prevent the need to
    // update the aliasDb
    for (auto tensor_it = ordered_tensor_inputs.rbegin();
         tensor_it != ordered_tensor_inputs.rend();
         ++tensor_it) {
      collectAndMergeLinearLayers(grouped_linear_layers.at(*tensor_it));
    }
  }

 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified = false;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
};
} // namespace

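// Entry point for the pass: concatenates groups of frozen linear layers that
// share an input into single, larger linear ops. Intended to be run on frozen
// graphs as part of the frozen graph optimizations (see
// frozen_graph_optimizations.h). Returns true if the graph was modified.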
TORCH_API bool FrozenConcatLinear(std::shared_ptr<Graph>& graph) {
  ConcatLinearLayers concatLayers(graph);
  GRAPH_DUMP("Before FrozenConcatLinear", graph);
  bool changed = concatLayers.run();
  if (changed) {
    GRAPH_DUMP("After FrozenConcatLinear", graph);
  }
  return changed;
}

} // namespace torch::jit