1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstring>
10
11 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
13
14 namespace torch {
15 namespace executor {
16
17 using Tensor = exec_aten::Tensor;
18
/**
 * Validates the arguments of an addmm-style op:
 *   out = beta * in + alpha * (mat1 @ mat2)
 *
 * Each failed check logs an error and makes this function return false
 * (via ET_LOG_AND_RETURN_IF_FALSE); returns true when all checks pass.
 *
 * NOTE(review): beta and alpha are accepted but not validated here, and
 * `in`'s rank/broadcastability to the (mat1.size(0), mat2.size(1)) output
 * is presumably verified elsewhere — confirm against the kernel callers.
 */
bool check_addmm_args(
    const Tensor& in,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& out) {
  // mat1, mat2 and out must all be 2-D matrices.
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat1, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));

  // All participating tensors must share a single dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat1, mat2));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  // Inner dimensions must agree: mat1 is (n, k) and mat2 is (k, m).
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(mat1, 1, mat2, 0));

  return true;
}
37
/**
 * Validates the arguments of a batched matrix multiply: out = in @ mat2,
 * with in of shape (b, n, k) and mat2 of shape (b, k, m).
 *
 * Each failed check logs an error and makes this function return false
 * (via ET_LOG_AND_RETURN_IF_FALSE); returns true when all checks pass.
 */
bool check_bmm_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
  // All tensors must be 3-D (a batch of 2-D matrices).
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 3));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 3));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 3));

  // All participating tensors must share a single dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));

  // Batch dimensions must match, and in's inner dim (k) must equal
  // mat2's row dim (k).
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 0, mat2, 0));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 2, mat2, 1));

  return true;
}
50
get_bmm_out_target_size(const Tensor & mat1,const Tensor & mat2,Tensor::SizesType * out_sizes,size_t * out_ndim)51 void get_bmm_out_target_size(
52 const Tensor& mat1,
53 const Tensor& mat2,
54 Tensor::SizesType* out_sizes,
55 size_t* out_ndim) {
56 *out_ndim = 3;
57 out_sizes[0] = mat1.size(0);
58 out_sizes[1] = mat1.size(1);
59 out_sizes[2] = mat2.size(2);
60 }
61
/**
 * Validates the arguments of a plain matrix multiply: out = in @ mat2,
 * with in of shape (n, k) and mat2 of shape (k, m).
 *
 * Each failed check logs an error and makes this function return false
 * (via ET_LOG_AND_RETURN_IF_FALSE); returns true when all checks pass.
 */
bool check_mm_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
  // All tensors must be 2-D matrices.
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));

  // All participating tensors must share a single dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));

  // Inner dimensions must agree: in's column count == mat2's row count.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, mat2, 0));

  return true;
}
73
/**
 * Validates the arguments of a linear op: out = in @ mat2^T, where mat2 is
 * a 2-D weight and `in` may have arbitrary leading (batch) dimensions.
 *
 * Each failed check logs an error and makes this function return false
 * (via ET_LOG_AND_RETURN_IF_FALSE); returns true when all checks pass.
 */
bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
  // Input and output must have the same rank, and at least rank 2;
  // the weight must be a 2-D matrix.
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() == out.dim());
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() >= 2);
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));

  // All participating tensors must share a single dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));

  // in's last dim must equal mat2's dim 1 — consistent with a weight laid
  // out as (out_features, in_features), applied as in @ mat2^T.
  ET_LOG_AND_RETURN_IF_FALSE(
      tensors_have_same_size_at_dims(in, in.dim() - 1, mat2, 1));

  return true;
}
86
get_mm_out_target_size(const Tensor & mat1,const Tensor & mat2,Tensor::SizesType * out_sizes,size_t * out_ndim)87 void get_mm_out_target_size(
88 const Tensor& mat1,
89 const Tensor& mat2,
90 Tensor::SizesType* out_sizes,
91 size_t* out_ndim) {
92 *out_ndim = 2;
93 out_sizes[0] = mat1.size(0);
94 out_sizes[1] = mat2.size(1);
95 }
96
get_linear_out_target_size(const Tensor & mat1,const Tensor & mat2,Tensor::SizesType * out_sizes,size_t * out_ndim)97 void get_linear_out_target_size(
98 const Tensor& mat1,
99 const Tensor& mat2,
100 Tensor::SizesType* out_sizes,
101 size_t* out_ndim) {
102 *out_ndim = mat1.dim();
103 for (int ii = 0; ii < mat1.dim() - 1; ++ii) {
104 out_sizes[ii] = mat1.sizes()[ii];
105 }
106 out_sizes[mat1.dim() - 1] = mat2.size(0);
107 }
108
109 } // namespace executor
110 } // namespace torch
111