#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/operators/matmul.h>

namespace torch::jit::tensorexpr {

Tensor computeMatmul(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("matmul", outputShape, dtype);
  const BufHandle a = std::get<BufHandle>(inputs[0]);
  const BufHandle b = std::get<BufHandle>(inputs[1]);

  auto size_a = a.dims();
  auto size_b = b.dims();
  // We currently only support rank-2 matmuls.
  TORCH_INTERNAL_ASSERT(size_a.size() == 2 && size_b.size() == 2);
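  // Compute M * K * N, i.e. the total number of multiply-accumulates. The
  // casts to int64_t let the simplifier fold the product into a single
  // LongImm when all dimensions are statically known; otherwise total_size
  // stays null and we fall through to the external call below.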
  auto total_size =
      to<LongImm>(IRSimplifier::simplify(
                      cast<int64_t>(size_a[0]) * cast<int64_t>(size_a[1]) *
                      cast<int64_t>(size_b[1]))
                      .node());

  // For small sizes, where N*M*K < 1000, lower the matmul to a naive 3-level
  // loopnest. The threshold is not tuned very carefully; in the future we
  // should fine-tune it and add more advanced native TE lowerings for
  // matmuls. For bigger sizes we generate a TE ExternalCall, which calls
  // aten::matmul.
  // A native lowering, even a naive one, is beneficial for small sizes
  // because it eliminates the dispatch overhead.
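  // Conceptually, the Reduce below generates a loopnest equivalent to:
  //   for (m = 0; m < M; m++)
  //     for (n = 0; n < N; n++)
  //       for (k = 0; k < K; k++)
  //         nnc_matmul[m, n] += a[m, k] * b[k, n];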
  if (total_size && total_size->value() < 1000) {
    return Reduce(
        "nnc_matmul",
        {size_a[0], size_b[1]},
        Sum(),
        [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
          return Load::make(a, {m, k}) * Load::make(b, {k, n});
        },
        {size_a[1]});
  } else {
    return Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
  }
}

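// addmm is always lowered to an ExternalCall into nnc_aten_addmm; the two
// trailing integer arguments are the beta and alpha scalars of the
// aten::addmm schema (beta * input + alpha * (mat1 @ mat2)).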
Tensor computeAddMM(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("addmm", outputShape, dtype);
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_addmm",
          {std::get<BufHandle>(inputs[0]),
           std::get<BufHandle>(inputs[1]),
           std::get<BufHandle>(inputs[2])},
          {std::get<int64_t>(inputs[3]),
           std::get<int64_t>(
               inputs[4])})); // TODO: handle other dtypes of alpha and beta
}

} // namespace torch::jit::tensorexpr