1 #include <torch/csrc/jit/passes/batch_mm.h>
2
3 #include <ATen/core/functional.h>
4 #include <ATen/core/symbol.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/irange.h>
7 #include <torch/csrc/jit/ir/alias_analysis.h>
8 #include <torch/csrc/jit/ir/constants.h>
9 #include <torch/csrc/jit/passes/dead_code_elimination.h>
10 #include <torch/csrc/jit/passes/peephole.h>
11 #include <torch/csrc/jit/runtime/custom_operator.h>
12 #include <torch/csrc/jit/runtime/graph_iterator.h>
13
14 #include <ATen/ATen.h>
15 #include <algorithm>
16 #include <unordered_map>
17 #include <utility>
18
19 namespace torch::jit {
20
21 namespace {
aliasAnalysisIsSpecialCase()22 c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
23 return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
24 }
25 } // namespace
26
27 // This pass looks for trees in the graph, where leaves are mm ops, and the
28 // inner vertices are add nodes. Once we have such a tree they can be reduced to
29 // two concats and a single mm (basically into a single multiply of a wide
30 // matrix, with a tall matrix). Such patterns show up mostly in backward of
31 // RNNs, since the derivative of many uses of matrix multiplies with same
32 // weights forms exactly such a tree (note that it's usually also highly
33 // imbalanced i.e. has O(n) depth).
34 //
35 // This (or any tree of adds of MMs):
36 //
37 // +------+ +------+ +------+ +------+ +------+
38 // | | | | | | | | | |
39 // | L1 | | R1 | + | L2 | | R2 | = | O |
40 // | | | | | | | | | |
41 // +------+ +------+ +------+ +------+ +------+
42 //
43 // can be basically transformed into a single MM which looks like this
44 // (we concat all lhs operands, concat rhs operands, do mm):
45 //
46 // +------+
47 // | |
48 // | R1 |
49 // | |
50 // +------+
51 // | |
52 // | R2 |
53 // | |
54 // +------+
55 // +------+------+ +------+
56 // | | | | |
57 // | L1 | L2 | | O |
58 // | | | | |
59 // +------+------+ +------+
60
61 // Note [Further optimizations]
62 // It would be straightforward to extend the TreeToken class to also detect if
63 // all MMs had the same lhs/rhs. In such case it's more efficient to expand the
64 // lhs and use bmm + sum instead of repeating it in memory via concat.
65
66 // Note [Overlapping trees]
// Additionally, it wouldn't be too hard to add support for partially
// overlapping trees. Right now it's forbidden in the algorithm (only a single
// tree will be allowed), so in theory we might miss some optimization
// opportunities, especially since the rejected tree could be much larger. I
// didn't implement that because it's not necessary for the simple RNN cases I
// saw, so I decided to keep things simple. If we ever get around to
// implementing this, the right solution is probably to fuse MMs for the common
// part and treat it as an input leaf for the outer two parts (I don't think
// it's beneficial to recompute, unless the subtree is super small, but let's
// not get into such details).
77
78 // The algorithm we're using is simple. We're iterating through the graph in the
79 // topological order and labeling nodes with TreeTokens. Then, we look for roots
80 // of the trees we formed and fuse them.
81
// Smallest number of mm leaves a tree needs before it is worth fusing into a
// single prim::MMTreeReduce. Tunable: raise it if larger trees turn out to be
// the better trade-off.
constexpr static size_t min_fusion_size = 4;
84
have_same_shape(at::TensorList inputs)85 static bool have_same_shape(at::TensorList inputs) {
86 auto expected_sizes = inputs[0].sizes();
87 return (std::all_of(
88 inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
89 return t.sizes() == expected_sizes;
90 }));
91 }
92
should_be_transposed(at::TensorList inputs)93 static bool should_be_transposed(at::TensorList inputs) {
94 return (std::all_of(inputs.begin(), inputs.end(), [](const at::Tensor& t) {
95 return t.stride(0) == 1 && t.stride(1) == t.size(0);
96 }));
97 }
98
transpose_inputs(at::TensorList inputs)99 static std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
100 return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
101 }
102
shape_is_fast_for_reduce(const at::Tensor & lhs,const at::Tensor & rhs)103 static bool shape_is_fast_for_reduce(
104 const at::Tensor& lhs,
105 const at::Tensor& rhs) {
106 size_t l = lhs.size(0);
107 size_t m = lhs.size(1);
108 size_t r = rhs.size(1);
109 // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
110 return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
111 }
112
113 RegisterOperators mm_tree_reduction_reg({Operator(
114 "prim::MMTreeReduce(...) -> Tensor",
__anon957c7f480502(Stack& stack) 115 [](Stack& stack) {
116 auto num_inputs = pop(stack).toInt();
117 std::vector<at::Tensor> inputs;
118 inputs.reserve(num_inputs);
119 for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
120 inputs.push_back(std::move(*it).toTensor());
121 }
122 drop(stack, num_inputs);
123
124 AT_ASSERT(!inputs.empty());
125 AT_ASSERT(inputs.size() % 2 == 0);
126 size_t side_num_elems = inputs.size() / 2;
127 auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
128 auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
129 // TODO: checking this is not free, so we should stop if this keeps
130 // failing
131 if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
132 shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
133 // sometimes lhs_inputs or rhs_inputs are not contiguous, and that
134 // causes at::cat to go through slow path view them as contiguous if
135 // possible by transposing
136 bool lhs_input_transposed = should_be_transposed(lhs_inputs);
137 bool rhs_input_transposed = should_be_transposed(rhs_inputs);
138 at::Tensor lhs, rhs;
139 if (lhs_input_transposed) {
140 std::vector<at::Tensor> lhs_contig_inputs =
141 transpose_inputs(lhs_inputs);
142 lhs = at::cat(lhs_contig_inputs, /*dim*/ 0);
143 lhs = lhs.t();
144 } else {
145 lhs = at::cat(lhs_inputs, /*dim=*/1);
146 }
147 if (rhs_input_transposed) {
148 std::vector<at::Tensor> rhs_contig_inputs =
149 transpose_inputs(rhs_inputs);
150 rhs = at::cat(rhs_contig_inputs, /*dim*/ 1);
151 rhs = rhs.t();
152 } else {
153 rhs = at::cat(rhs_inputs, /*dim=*/0);
154 }
155 push(stack, at::mm(lhs, rhs));
156 } else {
157 auto acc = at::mm(inputs[0], inputs[side_num_elems]);
158 for (const auto i : c10::irange(1, side_num_elems)) {
159 acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
160 }
161 push(stack, std::move(acc));
162 }
163 },
164 aliasAnalysisIsSpecialCase())});
165
166 // TreeTokens will be used to label nodes of the graph, if the nodes will fit
167 // our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
168 // when we reach node N with inputs A and B, then A and B have already been
169 // processed, and we can try to unify their TreeTokens (if they have them)
170 // and build a larger tree.
171 struct TreeToken {
172 uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
173 Node* node = nullptr;
174 bool is_root = false;
175
mmtorch::jit::TreeToken176 static TreeToken mm(Node* mm) {
177 TreeToken token;
178 token.tree_size = 1;
179 token.node = mm;
180 token.is_root = true;
181 return token;
182 }
183
184 // NB: the returned token might be invalid, so make sure to check its boolean
185 // value!
transposetorch::jit::TreeToken186 static TreeToken transpose(Node* t, TreeToken& inp_token) {
187 TreeToken token;
188 if (!inp_token.node->matches(
189 "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
190 return token;
191 }
192 token.tree_size = 1;
193 token.node = t;
194 token.is_root = true;
195 inp_token.is_root = false;
196 return token;
197 }
198
199 // NB: the returned token might be invalid, so make sure to check its boolean
200 // value!
addtorch::jit::TreeToken201 static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
202 TreeToken token;
203 // See Note [Overlapping trees]
204 if (&l == &r || !l.is_root || !r.is_root)
205 return token;
206 token.tree_size = l.tree_size + r.tree_size;
207 token.node = add;
208 token.is_root = true;
209 l.is_root = r.is_root =
210 false; // Reserve the subtrees, so they can't be used again.
211 return token;
212 }
213
operator booltorch::jit::TreeToken214 explicit operator bool() {
215 return is_root;
216 }
217
removeTransposesAndGatherMatmulstorch::jit::TreeToken218 std::vector<Node*> removeTransposesAndGatherMatmuls() {
219 std::vector<Node*> matmuls;
220 std::vector<Node*> queue{node};
221 Graph* graph = node->owningGraph();
222 while (!queue.empty()) {
223 auto n = queue.back();
224 queue.pop_back();
225 if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
226 matmuls.push_back(n);
227 } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
228 Node* input_node = n->input()->node();
229 AT_ASSERT(input_node->matches(
230 "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
231 // (AB)^T == B^TA^T
232 WithInsertPoint insert_guard{input_node};
233 Value* A = input_node->inputs()[0];
234 Value* B = input_node->inputs()[1];
235 Value* AT = graph->insert(aten::t, {A});
236 Value* BT = graph->insert(aten::t, {B});
237 Value* BTAT = graph->insert(aten::mm, {BT, AT});
238 n->output()->replaceAllUsesWith(BTAT);
239 matmuls.push_back(BTAT->node());
240 } else if (
241 n->matches(
242 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
243 queue.push_back(n->inputs()[0]->node());
244 queue.push_back(n->inputs()[1]->node());
245 } else {
246 AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
247 }
248 }
249 return matmuls;
250 }
251 };
252
// Which operand position is shared by every mm in a prim::MMBatchSide group:
// LHS means the shared value is each mm's first input, RHS its second. The
// integer value is stored in the node's "side" attribute.
enum class Side : int {
  LHS = 0,
  RHS = 1,
};
254
BatchMMTreeReduce(Block * block,AliasDb & alias_db)255 static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
256 auto graph = block->owningGraph();
257
258 // Look for trees in the block
259 std::unordered_map<Node*, TreeToken> tokens;
260 for (auto node : block->nodes()) {
261 if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
262 !alias_db.hasWriters(node)) {
263 tokens[node] = TreeToken::mm(node);
264 } else if (
265 node->matches("aten::t(Tensor self) -> Tensor") &&
266 !alias_db.hasWriters(node)) {
267 auto input_it = tokens.find(node->input()->node());
268 if (input_it != tokens.end()) {
269 tokens[node] = TreeToken::transpose(node, input_it->second);
270 }
271 } else if (
272 node->matches(
273 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
274 !alias_db.hasWriters(node)) {
275 Node* lhs = node->inputs()[0]->node();
276 Node* rhs = node->inputs()[1]->node();
277 auto lhs_it = tokens.find(lhs);
278 auto rhs_it = tokens.find(rhs);
279 // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
280 // We could treat a subtree with multiple uses as if it was overlapping.
281 // XXX: uses().size() == 1 is also something that guarantees that this
282 // transform is valid, because we know for sure that the none of these
283 // operands depend on the result of the other. If we were to remove this,
284 // we need to compute a transitive closure and actually check the
285 // dependencies.
286 if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
287 lhs->output()->uses().size() == 1 &&
288 rhs->output()->uses().size() == 1) {
289 if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
290 tokens[node] = token;
291 }
292 }
293 } else {
294 for (auto block : node->blocks()) {
295 BatchMMTreeReduce(block, alias_db);
296 }
297 }
298 }
299
300 // Merge trees we've found
301 for (auto& item : tokens) {
302 auto& root = item.second;
303 if (!root || root.tree_size < min_fusion_size)
304 continue;
305 auto matmuls = root.removeTransposesAndGatherMatmuls();
306 WithInsertPoint insert_guard{root.node};
307 Node* tree_reduce =
308 graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
309 for (Node* matmul : matmuls) {
310 tree_reduce->addInput(matmul->inputs().at(0));
311 }
312 for (Node* matmul : matmuls) {
313 tree_reduce->addInput(matmul->inputs().at(1));
314 }
315 root.node->output()->replaceAllUsesWith(tree_reduce->output());
316 // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
317 }
318 }
319
shape_is_fast_for_side(const at::Tensor & other_side_input)320 static bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
321 // Cutoff chosed by benchmarking on a TITAN V
322 return other_side_input.numel() <= 1024 * 2048;
323 }
324
325 RegisterOperators mm_batch_side_reg({Operator(
326 prim::MMBatchSide,
__anon957c7f480602(const Node* node) 327 [](const Node* node) -> Operation {
328 size_t num_other_side_inputs = node->inputs().size() - 1;
329 Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
330 return [num_other_side_inputs, single_side](Stack& stack) {
331 at::Tensor side_input;
332 std::vector<at::Tensor> other_side_inputs;
333 other_side_inputs.reserve(num_other_side_inputs);
334 for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
335 ++it) {
336 other_side_inputs.push_back(std::move(*it).toTensor());
337 }
338 drop(stack, num_other_side_inputs);
339 pop(stack, side_input);
340
341 auto any_other_input = other_side_inputs[0];
342 if (have_same_shape(other_side_inputs) &&
343 shape_is_fast_for_side(other_side_inputs[0])) {
344 auto other_side_input =
345 at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
346 auto mm_out = single_side == Side::LHS
347 ? side_input.mm(other_side_input)
348 : other_side_input.mm(side_input);
349 auto outputs = at::chunk(
350 mm_out,
351 num_other_side_inputs,
352 /*dim=*/single_side == Side::LHS ? 1 : 0);
353 stack.insert(
354 stack.end(),
355 std::make_move_iterator(outputs.begin()),
356 std::make_move_iterator(outputs.end()));
357 } else {
358 if (single_side == Side::LHS) {
359 for (at::Tensor& other : other_side_inputs) {
360 stack.emplace_back(side_input.mm(other));
361 }
362 } else {
363 for (at::Tensor& other : other_side_inputs) {
364 stack.emplace_back(other.mm(side_input));
365 }
366 }
367 }
368 };
369 },
370 aliasAnalysisIsSpecialCase())});
371
gatherIndependentMMUses(Value * value,AliasDb & alias_db)372 static std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
373 Value* value,
374 AliasDb& alias_db) {
375 const auto postprocess = [&](std::vector<Node*> mms) {
376 if (mms.empty()) {
377 return mms;
378 }
379 std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
380 return n->isBefore(m);
381 });
382 // Filter out dependent MMs. This algorithm might do very badly if e.g. you
383 // have a lot of independent MMs, that depend on the first one, but I doubt
384 // this will be a common scenario.
385 for (const auto i : c10::irange(mms.size())) {
386 if (mms[i] == nullptr)
387 continue;
388 for (size_t j = i + 1; j < mms.size(); ++j) {
389 if (mms[j] == nullptr)
390 continue;
391 if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
392 mms[j] = nullptr;
393 }
394 }
395 }
396 return c10::filter(mms, [](Node* n) { return n != nullptr; });
397 };
398
399 Block* block = value->node()->owningBlock();
400 std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
401 std::vector<Node*> rhses; // Like above, but rhs
402 for (Use u : value->uses()) {
403 if (u.user->owningBlock() == block &&
404 u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
405 !alias_db.hasWriters(u.user)) {
406 if (u.offset == 0 && u.user->inputs()[1] != value) {
407 lhses.push_back(u.user);
408 } else if (u.offset == 1 && u.user->inputs()[0] != value) {
409 rhses.push_back(u.user);
410 }
411 }
412 }
413 return std::make_pair(
414 postprocess(std::move(lhses)), postprocess(std::move(rhses)));
415 }
416
BatchMMSide(Block * block,AliasDb & alias_db)417 static void BatchMMSide(Block* block, AliasDb& alias_db) {
418 // NB: 8 is the current loop unrolling factor
419 static constexpr size_t how_many_is_many = 8;
420 const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
421 AT_ASSERT(!mms.empty());
422 for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
423 bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
424 AT_ASSERT(move_ok);
425 }
426 WithInsertPoint insert_guard{mms[0]};
427 Graph* graph = mms[0]->owningGraph();
428 Node* batch_mm = graph->create(
429 prim::MMBatchSide,
430 /*inputs=*/{},
431 /*num_outputs=*/mms.size());
432 graph->insertNode(batch_mm);
433 batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
434 Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
435 batch_mm->addInput(const_side);
436 for (const auto i : c10::irange(mms.size())) {
437 batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
438 mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
439 }
440 };
441
442 std::unordered_set<Value*> considered_values;
443 for (Node* node : block->nodes()) {
444 if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
445 !alias_db.hasWriters(node)) {
446 for (Value* input : node->inputs()) {
447 if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
448 continue;
449 }
450 auto uses_with_many = gatherIndependentMMUses(input, alias_db);
451 if (uses_with_many.first.size() >= how_many_is_many) {
452 batch_side(uses_with_many.first, Side::LHS);
453 }
454 if (uses_with_many.second.size() >= how_many_is_many) {
455 batch_side(uses_with_many.second, Side::RHS);
456 }
457 }
458 } else {
459 for (Block* subblock : node->blocks()) {
460 BatchMMSide(subblock, alias_db);
461 }
462 }
463 }
464 }
465
hasMMOperators(std::shared_ptr<Graph> & graph)466 static bool hasMMOperators(std::shared_ptr<Graph>& graph) {
467 DepthFirstGraphNodeIterator it(graph);
468 Node* n = nullptr;
469 while ((n = it.next()) != nullptr) {
470 if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
471 return true;
472 }
473 }
474 return false;
475 }
476
BatchMM(std::shared_ptr<Graph> & graph)477 void BatchMM(std::shared_ptr<Graph>& graph) {
478 if (!hasMMOperators(graph)) {
479 return;
480 }
481 AliasDb alias_db(graph);
482 BatchMMTreeReduce(graph->block(), alias_db);
483 BatchMMSide(graph->block(), alias_db);
484 EliminateDeadCode(graph);
485 // It's possible that transpose rearrangements have created sequences of
486 // consecutive transposes that didn't exist before.
487
488 // tensor type properties are not guaranteed to be correct
489 PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
490 }
491
492 } // namespace torch::jit
493