/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>

// Performs a batch matrix-matrix product of matrices stored in input and mat2.

// input and mat2 must be 3-D tensors, each containing the same number of
// matrices.

// If input is a (b × n × m) tensor and mat2 is a (b × m × p) tensor, out will
// be a (b × n × p) tensor.

// Note: This function does not broadcast. For broadcasting matrix products,
// see matmul().
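//
// Example (illustrative shapes only): an input of size (2 × 3 × 4) and a mat2
// of size (2 × 4 × 5) produce an out tensor of size (2 × 3 × 5).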
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

namespace {

// Verifies that the parameters are valid.
bool check_bmm_out_args(const Tensor& self, const Tensor& mat2, Tensor& out) {
  // Ensure that self, mat2, and out are all 3-D tensors.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.dim() == mat2.dim(),
      "self.dim() %zd != mat2.dim() %zd",
      self.dim(),
      mat2.dim());
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.dim() == out.dim(),
      "self.dim() %zd != out.dim() %zd",
      self.dim(),
      out.dim());
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.dim() == 3, "self.dim() %zd != 3", self.dim());
  // Ensure the batch size is non-negative.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(0) >= 0, "self.size(0) %zd < 0", self.size(0));
  // Ensure all tensors have the same batch size.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(0) == mat2.size(0),
      "self.size(0) %zd != mat2.size(0) %zd",
      self.size(0),
      mat2.size(0));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(0) == out.size(0),
      "self.size(0) %zd != out.size(0) %zd",
      self.size(0),
      out.size(0));
  // Ensure the out sizes are compatible with the input tensors.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      mat2.size(2) == out.size(2),
      "mat2.size(2) %zd != out.size(2) %zd",
      mat2.size(2),
      out.size(2));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      self.size(1) == out.size(1),
      "self.size(1) %zd != out.size(1) %zd",
      self.size(1),
      out.size(1));

  // Ensure that all tensors share a dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, mat2, out));

  return true;
}

template <typename CTYPE>
void bmm_kernel(const Tensor& self, const Tensor& mat2, Tensor& out) {
  using executorch::cpublas::TransposeType;

  if (self.numel() == 0 || mat2.numel() == 0 || out.numel() == 0) {
    return;
  }

  const CTYPE* b_data = self.const_data_ptr<CTYPE>();
  const CTYPE* a_data = mat2.const_data_ptr<CTYPE>();
  CTYPE* c_data = out.mutable_data_ptr<CTYPE>();

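  // Shapes: self is (batch, n, k) and mat2 is (batch, k, m), so each output
  // matrix in out is (n, m).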
  int64_t batch_size = self.size(0);
  int64_t n = self.size(1);
  int64_t k = self.size(2);
  int64_t m = mat2.size(2);

  for (int64_t i = 0; i < batch_size; ++i) {
    const CTYPE* a = a_data + i * m * k;
    const CTYPE* b = b_data + i * k * n;
    CTYPE* c = c_data + i * m * n;

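    // cpublas::gemm computes a column-major C = alpha * A * B + beta * C.
    // Reinterpreting a row-major buffer as column-major transposes it, so the
    // row-major product out = self @ mat2 is computed here as the
    // column-major product out^T = mat2^T * self^T. That is why mat2 is
    // passed as the first operand (a) and self as the second (b).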
    // clang-format off
    executorch::cpublas::gemm(
        TransposeType::NoTranspose, TransposeType::NoTranspose,
        m, n, k,
        static_cast<CTYPE>(1),
        a, m,
        b, k,
        static_cast<CTYPE>(0),
        c, m);
    // clang-format on
  }
}

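// Resizes out to the shape implied by the inputs: the leading batch dimension
// of self, then self's row count and mat2's column count.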
Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {
  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];

  const size_t m_dim = self.dim() - 2;
  const size_t n_dim = self.dim() - 1;

  // Validate the dimension indices before using them to read sizes.
  if (m_dim >= self.dim() || n_dim >= mat2.dim()) {
    ET_LOG(Error, "Incompatible matrix multiply dimensions.");
    return Error::InvalidArgument;
  }

  for (size_t i = 0; i < m_dim; i++) {
    expected_output_size[i] = self.size(i);
  }

  expected_output_size[m_dim] = self.size(m_dim);
  expected_output_size[n_dim] = mat2.size(n_dim);

  ArrayRef<exec_aten::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  return resize_tensor(out, output_size);
}
} // namespace

// bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
Tensor& opt_bmm_out(
    KernelRuntimeContext& context,
    const Tensor& self,
    const Tensor& mat2,
    Tensor& out) {
  (void)context;

  ET_KERNEL_CHECK(
      context,
      resize_out_tensor(self, mat2, out) == Error::Ok,
      InvalidArgument,
      out);
  ET_KERNEL_CHECK(
      context, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

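// Expands to one switch case per dtype; ET_FORALL_REAL_TYPES_AND below
// instantiates bmm_kernel for every real type plus Half.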
#define BMM_TENSOR(ctype, dtype)        \
  case ScalarType::dtype:               \
    bmm_kernel<ctype>(self, mat2, out); \
    break;

  auto scalar_type = self.scalar_type();
  switch (scalar_type) {
    ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR)
    default:
      ET_CHECK_MSG(
          false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type));
  }
#undef BMM_TENSOR

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch