/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "GEMM.h"

#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
namespace reference
{
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
{
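    // For each (depth, batch) slice this computes, element-wise:
    //   dst(row, col) = alpha * sum_{k = 0..K-1} A(row, k) * B(k, col) + beta * C(row, col)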
    // Create reference
    SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 };

    // Compute reference
    const int M = a.shape().y();
    const int N = b.shape().x();
    const int K = a.shape().x();
    const int D = a.shape().z(); // Number of matrices in a batch
    const int W = a.shape()[3];  // Number of batched GEMMs (e.g. the Winograd case)
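    // Illustrative example (values invented for this comment): an "a" shape of
    // [K, M, D, W] = [8, 4, 2, 3] describes 2 * 3 = 6 independent M x K LHS matrices,
    // stored contiguously slice after slice.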

    const int a_stride_z = K * M;
    const int a_stride_w = K * M * D;

    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide matrix B along the third dimension if B has fewer than 3 dimensions
    int       b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide matrix B along the fourth dimension if B has fewer than 4 dimensions

    // Note: There are 3 GEMM types: batched-GEMM, multi-GEMM, and batches of multi-GEMMs. The third dimension of tensor b is overloaded when b has exactly 3 dimensions:
    // it can be either the number of batches or the number of multis. Batched-GEMM is detected only when the third dimension of the "a" and "c" tensors is 1 and their number of dimensions is 4
    const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 && c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1;
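    // For instance (illustrative shapes only): a = [K, M, 1, W], b = [N, K, W] and c = [N, M, 1, W]
    // take the batched-GEMM path below, so B slides along the batch dimension (w) instead of the
    // degenerate depth dimension.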

    // Batched-GEMM
    if(is_batched_gemm)
    {
        b_stride_w = b_stride_z;
    }

    const int c_stride_z = N * M;
    const int c_stride_w = N * M * D;

#if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
    #pragma omp parallel for collapse(2)
#endif /* defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__)) */
    for(int w = 0; w < W; ++w)
    {
        for(int depth = 0; depth < D; ++depth)
        {
            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
            const int base_addr_c = depth * c_stride_z + w * c_stride_w;
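
            // Each slice is addressed row-major: A(row, k) lives at base_addr_a + row * K + k,
            // B(k, col) at base_addr_b + k * N + col, and the (row, col) output at base_addr_c + row * N + col.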
            for(int row = 0; row < M; ++row)
            {
                for(int col = 0; col < N; ++col)
                {
                    T acc(0);

                    for(int k = 0; k < K; ++k)
                    {
                        acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N];
                    }

                    // Finalize the result: alpha * A * B + beta * C
                    dst[base_addr_c + col + row * N] = alpha * acc + beta * c[base_addr_c + col + row * N];
                }
            }
        }
    }

    return dst;
}

template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
{
    // GEMM mixed-precision combines F32 accumulators with F16 multiplications
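    // Each F16 product is widened to F32 before it is accumulated, so long K reductions
    // keep F32 accuracy while the multiplications themselves stay in F16.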
    // Create reference
    SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 };

    // Compute reference
    const int M = a.shape().y();
    const int N = b.shape().x();
    const int K = a.shape().x();
    const int D = a.shape().z(); // Number of matrices in a batch
    const int W = a.shape()[3];  // Number of batched GEMMs (e.g. the Winograd case)

    const int a_stride_z = K * M;
    const int a_stride_w = K * M * D;

    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide matrix B along the third dimension if B has fewer than 3 dimensions
    int       b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide matrix B along the fourth dimension if B has fewer than 4 dimensions

    // Note: There are 3 GEMM types: batched-GEMM, multi-GEMM, and batches of multi-GEMMs. The third dimension of tensor b is overloaded when b has exactly 3 dimensions:
    // it can be either the number of batches or the number of multis. Batched-GEMM is detected only when the third dimension of the "a" and "c" tensors is 1 and their number of dimensions is 4
    const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 && c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1;

    // Batched-GEMM
    if(is_batched_gemm)
    {
        b_stride_w = b_stride_z;
    }

    const int c_stride_z = N * M;
    const int c_stride_w = N * M * D;

#if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
    #pragma omp parallel for collapse(2)
#endif /* defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__)) */
    for(int w = 0; w < W; ++w)
    {
        for(int depth = 0; depth < D; ++depth)
        {
            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
            const int base_addr_c = depth * c_stride_z + w * c_stride_w;

            for(int row = 0; row < M; ++row)
            {
                for(int col = 0; col < N; ++col)
                {
                    float acc(0);

                    for(int k = 0; k < K; ++k)
                    {
                        acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]);
                    }

                    // Finalize the result: alpha * A * B + beta * C
                    dst[base_addr_c + col + row * N] = static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]);
                }
            }
        }
    }

    return dst;
}

template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
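
// Example usage from a test (illustrative sketch only, not part of this file; tensor filling omitted).
// Shapes follow the [x, y] = [cols, rows] convention used above, with K = 8, M = 4, N = 16:
//   SimpleTensor<float> a{ TensorShape(8U, 4U), DataType::F32 };  // LHS, [K, M]
//   SimpleTensor<float> b{ TensorShape(16U, 8U), DataType::F32 }; // RHS, [N, K]
//   SimpleTensor<float> c{ TensorShape(16U, 4U), DataType::F32 }; // addend, [N, M]
//   SimpleTensor<float> dst = reference::gemm(a, b, c, 1.f, 0.f);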
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute