/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cinttypes> // For PRId8, used in the dtype error message below.

#include <executorch/runtime/kernel/kernel_includes.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>

// Performs a batch matrix-matrix product of the matrices stored in input and
// mat2.

// input and mat2 must be 3-D tensors, each containing the same number of
// matrices.

// If input is a (b × n × m) tensor and mat2 is a (b × m × p) tensor, out will
// be a (b × n × p) tensor.

// Note: This function does not broadcast. For broadcasting matrix products,
// see matmul().
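
// A concrete shape example (illustrative only; the sizes are hypothetical):
//   self: [2, 3, 5], mat2: [2, 5, 4]  ->  out: [2, 3, 4]
// i.e. two independent (3 × 5) @ (5 × 4) matrix products.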
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

namespace {

// Verifies that the parameters are valid.
bool check_bmm_out_args(const Tensor& self, const Tensor& mat2, Tensor& out) {
  // Ensure self, mat2, and out each have exactly 3 dimensions.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.dim() == mat2.dim(),
      "self.dim() %zd != mat2.dim() %zd",
      self.dim(),
      mat2.dim());
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.dim() == out.dim(),
      "self.dim() %zd != out.dim() %zd",
      self.dim(),
      out.dim());
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.dim() == 3, "self.dim() %zd != 3", self.dim());
  // Ensure the batch size is non-negative.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(0) >= 0, "self.size(0) %zd < 0", self.size(0));
  // Ensure all tensors share the same batch size.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(0) == mat2.size(0),
      "self.size(0) %zd != mat2.size(0) %zd",
      self.size(0),
      mat2.size(0));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(0) == out.size(0),
      "self.size(0) %zd != out.size(0) %zd",
      self.size(0),
      out.size(0));
  // Ensure the out size is compatible with the input tensors.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      mat2.size(2) == out.size(2),
      "mat2.size(2) %zd != out.size(2) %zd",
      mat2.size(2),
      out.size(2));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(1) == out.size(1),
      "self.size(1) %zd != out.size(1) %zd",
      self.size(1),
      out.size(1));

  // Ensure that all tensors share a dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, mat2, out));

  return true;
}

template <typename CTYPE>
void bmm_kernel(const Tensor& self, const Tensor& mat2, Tensor& out) {
  using executorch::cpublas::TransposeType;

  if (self.numel() == 0 || mat2.numel() == 0 || out.numel() == 0) {
    return;
  }

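  // Note: cpublas::gemm follows the column-major BLAS convention, while these
  // tensors are row-major. Computing out = self @ mat2 in row-major layout is
  // equivalent to computing out^T = mat2^T @ self^T in column-major layout,
  // which is why mat2 feeds the "a" operand and self the "b" operand below.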
  const CTYPE* b_data = self.const_data_ptr<CTYPE>();
  const CTYPE* a_data = mat2.const_data_ptr<CTYPE>();
  CTYPE* c_data = out.mutable_data_ptr<CTYPE>();

  int64_t batch_size = self.size(0);
  int64_t n = self.size(1);
  int64_t k = self.size(2);
  int64_t m = mat2.size(2);

  for (int64_t i = 0; i < batch_size; ++i) {
    const CTYPE* a = a_data + i * m * k;
    const CTYPE* b = b_data + i * k * n;
    CTYPE* c = c_data + i * m * n;

    // clang-format off
    executorch::cpublas::gemm(
        TransposeType::NoTranspose, TransposeType::NoTranspose,
        m, n, k,
        static_cast<CTYPE>(1),
        a, m,
        b, k,
        static_cast<CTYPE>(0),
        c, m);
    // clang-format on
  }
}

Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {
  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];

  const size_t m_dim = self.dim() - 2;
  const size_t n_dim = self.dim() - 1;

  // Validate the dimensions before writing into expected_output_size: if
  // self.dim() < 2, m_dim and n_dim wrap around as unsigned values, and the
  // copy loop below would run past the end of the buffer.
  if (m_dim >= self.dim() || n_dim >= mat2.dim()) {
    ET_LOG(Error, "Incompatible matrix multiply dimensions.");
    return Error::InvalidArgument;
  }

  for (size_t i = 0; i < m_dim; i++) {
    expected_output_size[i] = self.size(i);
  }

  expected_output_size[m_dim] = self.size(m_dim);
  expected_output_size[n_dim] = mat2.size(n_dim);

  ArrayRef<exec_aten::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  return resize_tensor(out, output_size);
}
} // namespace

// bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
Tensor& opt_bmm_out(
    KernelRuntimeContext& context,
    const Tensor& self,
    const Tensor& mat2,
    Tensor& out) {
  (void)context;

  ET_KERNEL_CHECK(
      context,
      resize_out_tensor(self, mat2, out) == Error::Ok,
      InvalidArgument,
      out);
  ET_KERNEL_CHECK(
      context, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

#define BMM_TENSOR(ctype, dtype)        \
  case ScalarType::dtype:               \
    bmm_kernel<ctype>(self, mat2, out); \
    break;

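  // ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR) expands BMM_TENSOR once per
  // real dtype plus Half, generating one `case` per supported scalar type.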
  auto scalar_type = self.scalar_type();
  switch (scalar_type) {
    ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR)
    default:
      ET_CHECK_MSG(
          false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type));
  }
#undef BMM_TENSOR

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch