#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/operators/matmul.h>

namespace torch::jit::tensorexpr {
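// Lowering for aten::matmul, restricted to rank-2 inputs: small, statically
// sized matmuls get a native TE loopnest, everything else an external call
// (see the heuristic comment inside).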
Tensor computeMatmul(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("matmul", outputShape, dtype);
  const BufHandle a = std::get<BufHandle>(inputs[0]);
  const BufHandle b = std::get<BufHandle>(inputs[1]);

  auto size_a = a.dims();
  auto size_b = b.dims();
  // We currently only support rank 2 matmuls
  TORCH_INTERNAL_ASSERT(size_a.size() == 2 && size_b.size() == 2);
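  // Fold M * K * N into a single constant. If any dimension is symbolic, the
  // simplified expression is not a LongImm and total_size stays null.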
  auto total_size =
      to<LongImm>(IRSimplifier::simplify(
                      cast<int64_t>(size_a[0]) * cast<int64_t>(size_a[1]) *
                      cast<int64_t>(size_b[1]))
                      .node());

  // For small sizes, where N * M * K < 1000, lower matmul to a naive 3-level
  // loopnest. The threshold is not carefully tuned; in the future we should
  // fine-tune it and add more advanced native TE lowerings for matmuls. For
  // bigger sizes we generate a TE ExternalCall, which calls aten::matmul.
  // A native lowering, even a naive one, is beneficial for small sizes because
  // it eliminates the dispatch overhead.
  if (total_size && total_size->value() < 1000) {
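    // Build result[m, n] = sum_k a[m, k] * b[k, n] over output dims {M, N},
    // with K (size_a[1]) as the reduction dimension.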
    return Reduce(
        "nnc_matmul",
        {size_a[0], size_b[1]},
        Sum(),
        [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
          return Load::make(a, {m, k}) * Load::make(b, {k, n});
        },
        {size_a[1]});
  } else {
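    // Emit an ExternalCall node; at runtime it dispatches to the
    // "nnc_aten_matmul" external function, which falls back to aten::matmul.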
    return Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
  }
}

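// Lowering for aten::addmm: always emitted as an external call, with no
// native TE loopnest variant.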
Tensor computeAddMM(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("addmm", outputShape, dtype);
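  // The three buffer inputs are the self, mat1, and mat2 arguments of
  // aten::addmm; the two integer extra-args forward the beta and alpha
  // scalars, which are currently handled as int64_t only (see TODO below).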
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_addmm",
          {std::get<BufHandle>(inputs[0]),
           std::get<BufHandle>(inputs[1]),
           std::get<BufHandle>(inputs[2])},
          {std::get<int64_t>(inputs[3]),
           std::get<int64_t>(
               inputs[4])})); // TODO: handle other dtypes of alpha and beta
}

} // namespace torch::jit::tensorexpr