xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/matmul_ops_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstring>
10 
11 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
13 
14 namespace torch {
15 namespace executor {
16 
17 using Tensor = exec_aten::Tensor;
18 
check_addmm_args(const Tensor & in,const Tensor & mat1,const Tensor & mat2,const Scalar & beta,const Scalar & alpha,Tensor & out)19 bool check_addmm_args(
20     const Tensor& in,
21     const Tensor& mat1,
22     const Tensor& mat2,
23     const Scalar& beta,
24     const Scalar& alpha,
25     Tensor& out) {
26   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat1, 2));
27   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
28   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));
29 
30   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat1, mat2));
31   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
32 
33   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(mat1, 1, mat2, 0));
34 
35   return true;
36 }
37 
check_bmm_args(const Tensor & in,const Tensor & mat2,Tensor & out)38 bool check_bmm_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
39   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 3));
40   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 3));
41   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 3));
42 
43   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));
44 
45   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 0, mat2, 0));
46   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 2, mat2, 1));
47 
48   return true;
49 }
50 
get_bmm_out_target_size(const Tensor & mat1,const Tensor & mat2,Tensor::SizesType * out_sizes,size_t * out_ndim)51 void get_bmm_out_target_size(
52     const Tensor& mat1,
53     const Tensor& mat2,
54     Tensor::SizesType* out_sizes,
55     size_t* out_ndim) {
56   *out_ndim = 3;
57   out_sizes[0] = mat1.size(0);
58   out_sizes[1] = mat1.size(1);
59   out_sizes[2] = mat2.size(2);
60 }
61 
check_mm_args(const Tensor & in,const Tensor & mat2,Tensor & out)62 bool check_mm_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
63   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
64   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
65   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));
66 
67   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));
68 
69   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, mat2, 0));
70 
71   return true;
72 }
73 
check_linear_args(const Tensor & in,const Tensor & mat2,Tensor & out)74 bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
75   ET_LOG_AND_RETURN_IF_FALSE(in.dim() == out.dim());
76   ET_LOG_AND_RETURN_IF_FALSE(in.dim() >= 2);
77   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
78 
79   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));
80 
81   ET_LOG_AND_RETURN_IF_FALSE(
82       tensors_have_same_size_at_dims(in, in.dim() - 1, mat2, 1));
83 
84   return true;
85 }
86 
get_mm_out_target_size(const Tensor & mat1,const Tensor & mat2,Tensor::SizesType * out_sizes,size_t * out_ndim)87 void get_mm_out_target_size(
88     const Tensor& mat1,
89     const Tensor& mat2,
90     Tensor::SizesType* out_sizes,
91     size_t* out_ndim) {
92   *out_ndim = 2;
93   out_sizes[0] = mat1.size(0);
94   out_sizes[1] = mat2.size(1);
95 }
96 
get_linear_out_target_size(const Tensor & mat1,const Tensor & mat2,Tensor::SizesType * out_sizes,size_t * out_ndim)97 void get_linear_out_target_size(
98     const Tensor& mat1,
99     const Tensor& mat2,
100     Tensor::SizesType* out_sizes,
101     size_t* out_ndim) {
102   *out_ndim = mat1.dim();
103   for (int ii = 0; ii < mat1.dim() - 1; ++ii) {
104     out_sizes[ii] = mat1.sizes()[ii];
105   }
106   out_sizes[mat1.dim() - 1] = mat2.size(0);
107 }
108 
109 } // namespace executor
110 } // namespace torch
111