xref: /aosp_15_r20/external/gemmlowp/internal/kernel_reference.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // kernel_reference.h: a reference kernel for CPU architectures where we don't
16*5f39d1b3SJooyung Han // have optimized kernels yet. Also useful for testing, as it's templatized
17*5f39d1b3SJooyung Han // to have any arbitrary format, allowing tests to cover all sorts of corner
18*5f39d1b3SJooyung Han // cases.
19*5f39d1b3SJooyung Han 
20*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
21*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
22*5f39d1b3SJooyung Han 
23*5f39d1b3SJooyung Han #include "kernel.h"
24*5f39d1b3SJooyung Han 
25*5f39d1b3SJooyung Han #include <cstdio>
26*5f39d1b3SJooyung Han #include <cstring>
27*5f39d1b3SJooyung Han 
28*5f39d1b3SJooyung Han namespace gemmlowp {
29*5f39d1b3SJooyung Han 
30*5f39d1b3SJooyung Han // This kernel is templatized in an arbitrary Format template parameter,
31*5f39d1b3SJooyung Han // allowing it to have any arbitrary format.
32*5f39d1b3SJooyung Han template <typename tFormat>
33*5f39d1b3SJooyung Han struct ReferenceKernel : KernelBase {
34*5f39d1b3SJooyung Han   typedef tFormat Format;
35*5f39d1b3SJooyung Han 
NameReferenceKernel36*5f39d1b3SJooyung Han   const char* Name() const override {
37*5f39d1b3SJooyung Han     static char buf[256];
38*5f39d1b3SJooyung Han     snprintf(buf, sizeof(buf),
39*5f39d1b3SJooyung Han              "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)",
40*5f39d1b3SJooyung Han              Format::Lhs::kCells, Format::Lhs::Cell::kWidth,
41*5f39d1b3SJooyung Han              Format::Lhs::Cell::kDepth,
42*5f39d1b3SJooyung Han              CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells,
43*5f39d1b3SJooyung Han              Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth,
44*5f39d1b3SJooyung Han              CellOrderName(Format::Rhs::Cell::kOrder));
45*5f39d1b3SJooyung Han     return buf;
46*5f39d1b3SJooyung Han   }
47*5f39d1b3SJooyung Han 
RunReferenceKernel48*5f39d1b3SJooyung Han   void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
49*5f39d1b3SJooyung Han            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
50*5f39d1b3SJooyung Han            const std::uint8_t* rhs_ptr, std::size_t start_depth,
51*5f39d1b3SJooyung Han            std::size_t run_depth) const override {
52*5f39d1b3SJooyung Han     std::int32_t accumulator[Format::kRows * Format::kCols];
53*5f39d1b3SJooyung Han     memset(accumulator, 0, sizeof(accumulator));
54*5f39d1b3SJooyung Han 
55*5f39d1b3SJooyung Han     const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth);
56*5f39d1b3SJooyung Han 
57*5f39d1b3SJooyung Han     // The outer loop is over the depth dimension.
58*5f39d1b3SJooyung Han     for (int dc = 0; dc < run_depth_cells; dc++) {
59*5f39d1b3SJooyung Han       // The next two loops are over cells of the Lhs (stacked vertically),
60*5f39d1b3SJooyung Han       // and over cells of the Rhs (stacked horizontally).
61*5f39d1b3SJooyung Han       for (int rc = 0; rc < Format::Lhs::kCells; rc++) {
62*5f39d1b3SJooyung Han         const std::uint8_t* lhs_cell_ptr =
63*5f39d1b3SJooyung Han             lhs_ptr + (dc * Format::Lhs::kCells + rc) *
64*5f39d1b3SJooyung Han                           Format::Lhs::Cell::kWidth * Format::kDepth;
65*5f39d1b3SJooyung Han         for (int cc = 0; cc < Format::Rhs::kCells; cc++) {
66*5f39d1b3SJooyung Han           const std::uint8_t* rhs_cell_ptr =
67*5f39d1b3SJooyung Han               rhs_ptr + (dc * Format::Rhs::kCells + cc) *
68*5f39d1b3SJooyung Han                             Format::Rhs::Cell::kWidth * Format::kDepth;
69*5f39d1b3SJooyung Han 
70*5f39d1b3SJooyung Han           // Now we are inside one cell of the Lhs and inside one cell
71*5f39d1b3SJooyung Han           // of the Rhs, so the remaining inner loops are just
72*5f39d1b3SJooyung Han           // traditional three loops of matrix multiplication.
73*5f39d1b3SJooyung Han           for (int di = 0; di < Format::kDepth; di++) {
74*5f39d1b3SJooyung Han             for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) {
75*5f39d1b3SJooyung Han               for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) {
76*5f39d1b3SJooyung Han                 const std::uint8_t* lhs_coeff_ptr =
77*5f39d1b3SJooyung Han                     lhs_cell_ptr +
78*5f39d1b3SJooyung Han                     OffsetIntoCell<typename Format::Lhs::Cell>(ri, di);
79*5f39d1b3SJooyung Han                 const std::uint8_t* rhs_coeff_ptr =
80*5f39d1b3SJooyung Han                     rhs_cell_ptr +
81*5f39d1b3SJooyung Han                     OffsetIntoCell<typename Format::Rhs::Cell>(ci, di);
82*5f39d1b3SJooyung Han                 std::int32_t* accumulator_coeff_ptr =
83*5f39d1b3SJooyung Han                     accumulator + (ri + rc * Format::Lhs::Cell::kWidth) +
84*5f39d1b3SJooyung Han                     (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows;
85*5f39d1b3SJooyung Han                 *accumulator_coeff_ptr +=
86*5f39d1b3SJooyung Han                     std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr);
87*5f39d1b3SJooyung Han               }
88*5f39d1b3SJooyung Han             }
89*5f39d1b3SJooyung Han           }
90*5f39d1b3SJooyung Han         }
91*5f39d1b3SJooyung Han       }
92*5f39d1b3SJooyung Han     }
93*5f39d1b3SJooyung Han 
94*5f39d1b3SJooyung Han     if (start_depth == 0) {
95*5f39d1b3SJooyung Han       // start_depth == 0 means we haven't accumulated anything yet, so we need
96*5f39d1b3SJooyung Han       // to overwrite the accumulator, as it hasn't been initialized to zero.
97*5f39d1b3SJooyung Han       for (int r = 0; r < Format::kRows; r++) {
98*5f39d1b3SJooyung Han         for (int c = 0; c < Format::kCols; c++) {
99*5f39d1b3SJooyung Han           dst_ptr[r * dst_row_stride + c * dst_col_stride] =
100*5f39d1b3SJooyung Han               accumulator[r + c * Format::kRows];
101*5f39d1b3SJooyung Han         }
102*5f39d1b3SJooyung Han       }
103*5f39d1b3SJooyung Han     } else {
104*5f39d1b3SJooyung Han       // We have already accumulated stuff, so we need to continue accumulating
105*5f39d1b3SJooyung Han       // instead of just overwriting.
106*5f39d1b3SJooyung Han       for (int r = 0; r < Format::kRows; r++) {
107*5f39d1b3SJooyung Han         for (int c = 0; c < Format::kCols; c++) {
108*5f39d1b3SJooyung Han           dst_ptr[r * dst_row_stride + c * dst_col_stride] +=
109*5f39d1b3SJooyung Han               accumulator[r + c * Format::kRows];
110*5f39d1b3SJooyung Han         }
111*5f39d1b3SJooyung Han       }
112*5f39d1b3SJooyung Han     }
113*5f39d1b3SJooyung Han   }
114*5f39d1b3SJooyung Han };
115*5f39d1b3SJooyung Han 
116*5f39d1b3SJooyung Han }  // namespace gemmlowp
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
119