#include <torch/csrc/jit/passes/batch_mm.h>

#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch::jit {

namespace {
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace

// This pass looks for trees in the graph, where leaves are mm ops, and the
// inner vertices are add nodes. Once we have such a tree, it can be reduced to
// two concats and a single mm (basically a single multiply of a wide matrix
// with a tall matrix). Such patterns show up mostly in the backward of RNNs,
// since the derivative of many uses of matrix multiplies with the same weights
// forms exactly such a tree (note that it's usually also highly imbalanced,
// i.e. it has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+   +------+ +------+   +------+
// |      | |      |   |      | |      |   |      |
// |  L1  | |  R1  | + |  L2  | |  R2  | = |  O   |
// |      | |      |   |      | |      |   |      |
// +------+ +------+   +------+ +------+   +------+
//
// can basically be transformed into a single MM that looks like this
// (we concat all lhs operands, concat all rhs operands, and do one mm):
//
//                 +------+
//                 |      |
//                 |  R1  |
//                 |      |
//                 +------+
//                 |      |
//                 |  R2  |
//                 |      |
//                 +------+
// +------+------+ +------+
// |      |      | |      |
// |  L1  |  L2  | |  O   |
// |      |      | |      |
// +------+------+ +------+
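//
// A minimal sketch of the identity being exploited, written with ATen calls
// for illustration only (L1, R1, L2, R2 are hypothetical 2D tensors with
// shapes that make both expressions well-formed):
//
//   at::Tensor tree  = at::mm(L1, R1) + at::mm(L2, R2);
//   at::Tensor fused = at::mm(at::cat({L1, L2}, /*dim=*/1),
//                             at::cat({R1, R2}, /*dim=*/0));
//   // `tree` and `fused` are equal up to floating point rounding.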

// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if
// all MMs had the same lhs/rhs. In such a case it's more efficient to expand
// the lhs and use bmm + sum instead of repeating it in memory via concat.
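//
// A minimal sketch of that bmm + sum alternative (illustrative only; this is
// not implemented in this pass; `lhs` is the shared left operand and
// `rhs_list` is a hypothetical vector of the per-leaf right operands):
//
//   at::Tensor stacked  = at::stack(rhs_list, /*dim=*/0);         // [n, k, r]
//   at::Tensor expanded = lhs.unsqueeze(0).expand(
//       {stacked.size(0), lhs.size(0), lhs.size(1)});             // [n, l, k]
//   at::Tensor out = at::bmm(expanded, stacked).sum(/*dim=*/0);   // [l, r]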

// Note [Overlapping trees]
// Additionally, it wouldn't be too hard to add support for partially
// overlapping trees. Right now it's forbidden by the algorithm (only a single
// tree will be allowed), so theoretically we might miss some optimization
// opportunities, especially since the rejected tree could be much larger. I
// didn't implement that because it's not necessary for the simple RNN cases I
// saw, so I decided to keep things simple. If we ever get around to
// implementing this, the right solution is probably to fuse MMs for the common
// part, and treat it as an input leaf for the outer two parts (I don't think
// it's beneficial to recompute, unless the subtree is super small, but let's
// not get into such details).

// The algorithm we're using is simple. We iterate through the graph in
// topological order and label nodes with TreeTokens. Then, we look for the
// roots of the trees we formed and fuse them.

// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 4;

static bool have_same_shape(at::TensorList inputs) {
  auto expected_sizes = inputs[0].sizes();
  return (std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      }));
}

static bool should_be_transposed(at::TensorList inputs) {
  return (std::all_of(inputs.begin(), inputs.end(), [](const at::Tensor& t) {
    return t.stride(0) == 1 && t.stride(1) == t.size(0);
  }));
}

static std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
  return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
}

static bool shape_is_fast_for_reduce(
    const at::Tensor& lhs,
    const at::Tensor& rhs) {
  size_t l = lhs.size(0);
  size_t m = lhs.size(1);
  size_t r = rhs.size(1);
  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}

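// prim::MMTreeReduce is the fused form of a whole add/mm tree. Its inputs are
// all of the tree's lhs operands followed by all of its rhs operands (see
// BatchMMTreeReduce below, which constructs the node in that order), so a
// fused tree with two leaves would look roughly like this in the IR
// (illustrative only):
//
//   %out : Tensor = prim::MMTreeReduce(%L1, %L2, %R1, %R2)
//
// which computes the same result as mm(%L1, %R1) + mm(%L2, %R2).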
RegisterOperators mm_tree_reduction_reg({Operator(
    "prim::MMTreeReduce(...) -> Tensor",
    [](Stack& stack) {
      auto num_inputs = pop(stack).toInt();
      std::vector<at::Tensor> inputs;
      inputs.reserve(num_inputs);
      for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
        inputs.push_back(std::move(*it).toTensor());
      }
      drop(stack, num_inputs);

      AT_ASSERT(!inputs.empty());
      AT_ASSERT(inputs.size() % 2 == 0);
      size_t side_num_elems = inputs.size() / 2;
      auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
      auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
      // TODO: checking this is not free, so we should stop if this keeps
      // failing
      if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
          shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
        // Sometimes lhs_inputs or rhs_inputs are not contiguous, which makes
        // at::cat go through the slow path; view them as contiguous, if
        // possible, by transposing.
        bool lhs_input_transposed = should_be_transposed(lhs_inputs);
        bool rhs_input_transposed = should_be_transposed(rhs_inputs);
        at::Tensor lhs, rhs;
        if (lhs_input_transposed) {
          std::vector<at::Tensor> lhs_contig_inputs =
              transpose_inputs(lhs_inputs);
          lhs = at::cat(lhs_contig_inputs, /*dim=*/0);
          lhs = lhs.t();
        } else {
          lhs = at::cat(lhs_inputs, /*dim=*/1);
        }
        if (rhs_input_transposed) {
          std::vector<at::Tensor> rhs_contig_inputs =
              transpose_inputs(rhs_inputs);
          rhs = at::cat(rhs_contig_inputs, /*dim=*/1);
          rhs = rhs.t();
        } else {
          rhs = at::cat(rhs_inputs, /*dim=*/0);
        }
        push(stack, at::mm(lhs, rhs));
      } else {
        auto acc = at::mm(inputs[0], inputs[side_num_elems]);
        for (const auto i : c10::irange(1, side_num_elems)) {
          acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
        }
        push(stack, std::move(acc));
      }
    },
    aliasAnalysisIsSpecialCase())});

// TreeTokens will be used to label nodes of the graph, if the nodes fit our
// mm/add tree pattern. Basically we do dynamic programming on DAGs: by the
// time we reach node N with inputs A and B, both A and B have already been
// processed, and we can try to unify their TreeTokens (if they have them) to
// build a larger tree.
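//
// For example, in a subgraph like the following (names illustrative only):
//
//   %1 : Tensor = aten::mm(%a, %b)
//   %2 : Tensor = aten::mm(%c, %d)
//   %3 : Tensor = aten::add(%1, %2, %alpha)
//
// %1 and %2 each get a token with tree_size == 1, and %3 unifies them into a
// token with tree_size == 2 whose root is the add node.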
struct TreeToken {
  uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
  Node* node = nullptr;
  bool is_root = false;

  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    if (!inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    // See Note [Overlapping trees]
    if (&l == &r || !l.is_root || !r.is_root)
      return token;
    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    l.is_root = r.is_root =
        false; // Reserve the subtrees, so they can't be used again.
    return token;
  }

  explicit operator bool() {
    return is_root;
  }

  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // (AB)^T == B^T A^T
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (
          n->matches(
              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};

enum class Side { LHS, RHS };

static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
  auto graph = block->owningGraph();

  // Look for trees in the block
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      tokens[node] = TreeToken::mm(node);
    } else if (
        node->matches("aten::t(Tensor self) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        tokens[node] = TreeToken::transpose(node, input_it->second);
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
      // We could treat a subtree with multiple uses as if it was overlapping.
      // XXX: uses().size() == 1 is also something that guarantees that this
      // transform is valid, because we know for sure that none of these
      // operands depend on the result of the other. If we were to remove this,
      // we would need to compute a transitive closure and actually check the
      // dependencies.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto block : node->blocks()) {
        BatchMMTreeReduce(block, alias_db);
      }
    }
  }

  // Merge the trees we've found
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());
    // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
  }
}

static bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  // Cutoff chosen by benchmarking on a TITAN V
  return other_side_input.numel() <= 1024 * 2048;
}

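// prim::MMBatchSide fuses several mms that all share one operand (the "side"
// input, which comes first), with the shared side recorded in the node's
// "side" attribute. For the LHS case the identity being used is, roughly
// (illustrative only; o1 and o2 are hypothetical right operands):
//
//   // side.mm(o1), side.mm(o2)  ==  the two column-chunks of
//   // side.mm(at::cat({o1, o2}, /*dim=*/1))
//
// The RHS case is symmetric, concatenating and chunking along dim 0.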
RegisterOperators mm_batch_side_reg({Operator(
    prim::MMBatchSide,
    [](const Node* node) -> Operation {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);

        auto any_other_input = other_side_inputs[0];
        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(other_side_inputs[0])) {
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
      };
    },
    aliasAnalysisIsSpecialCase())});

static std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
    Value* value,
    AliasDb& alias_db) {
  const auto postprocess = [&](std::vector<Node*> mms) {
    if (mms.empty()) {
      return mms;
    }
    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
      return n->isBefore(m);
    });
    // Filter out dependent MMs. This algorithm might do very badly if e.g. you
    // have a lot of independent MMs that depend on the first one, but I doubt
    // this will be a common scenario.
    for (const auto i : c10::irange(mms.size())) {
      if (mms[i] == nullptr)
        continue;
      for (size_t j = i + 1; j < mms.size(); ++j) {
        if (mms[j] == nullptr)
          continue;
        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
          mms[j] = nullptr;
        }
      }
    }
    return c10::filter(mms, [](Node* n) { return n != nullptr; });
  };

  Block* block = value->node()->owningBlock();
  std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
  std::vector<Node*> rhses; // Like above, but rhs
  for (Use u : value->uses()) {
    if (u.user->owningBlock() == block &&
        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(u.user)) {
      if (u.offset == 0 && u.user->inputs()[1] != value) {
        lhses.push_back(u.user);
      } else if (u.offset == 1 && u.user->inputs()[0] != value) {
        rhses.push_back(u.user);
      }
    }
  }
  return std::make_pair(
      postprocess(std::move(lhses)), postprocess(std::move(rhses)));
}

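// When a value is used as the same side of at least `how_many_is_many`
// independent mms, those mms get rewritten into a single multi-output
// prim::MMBatchSide node whose integer "side" attribute records which operand
// is shared. Roughly, a group of lhs uses of %W such as (illustrative only)
//
//   %o1 : Tensor = aten::mm(%W, %x1)
//   %o2 : Tensor = aten::mm(%W, %x2)
//
// becomes
//
//   %o1 : Tensor, %o2 : Tensor = prim::MMBatchSide(%W, %x1, %x2)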
static void BatchMMSide(Block* block, AliasDb& alias_db) {
  // NB: 8 is the current loop unrolling factor
  static constexpr size_t how_many_is_many = 8;
  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
    AT_ASSERT(!mms.empty());
    for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
      AT_ASSERT(move_ok);
    }
    WithInsertPoint insert_guard{mms[0]};
    Graph* graph = mms[0]->owningGraph();
    Node* batch_mm = graph->create(
        prim::MMBatchSide,
        /*inputs=*/{},
        /*num_outputs=*/mms.size());
    graph->insertNode(batch_mm);
    batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
    Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
    batch_mm->addInput(const_side);
    for (const auto i : c10::irange(mms.size())) {
      batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
      mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
    }
  };

  std::unordered_set<Value*> considered_values;
  for (Node* node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      for (Value* input : node->inputs()) {
        if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
          continue;
        }
        auto uses_with_many = gatherIndependentMMUses(input, alias_db);
        if (uses_with_many.first.size() >= how_many_is_many) {
          batch_side(uses_with_many.first, Side::LHS);
        }
        if (uses_with_many.second.size() >= how_many_is_many) {
          batch_side(uses_with_many.second, Side::RHS);
        }
      }
    } else {
      for (Block* subblock : node->blocks()) {
        BatchMMSide(subblock, alias_db);
      }
    }
  }
}

static bool hasMMOperators(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return true;
    }
  }
  return false;
}

void BatchMM(std::shared_ptr<Graph>& graph) {
  if (!hasMMOperators(graph)) {
    return;
  }
  AliasDb alias_db(graph);
  BatchMMTreeReduce(graph->block(), alias_db);
  BatchMMSide(graph->block(), alias_db);
  EliminateDeadCode(graph);
  // It's possible that transpose rearrangements have created sequences of
  // consecutive transposes that didn't exist before, so run a peephole pass.
  // Tensor type properties are not guaranteed to be correct at this point,
  // so shape peepholes are disabled.
  PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
}

} // namespace torch::jit