1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2*5f39d1b3SJooyung Han // 3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License"); 4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License. 5*5f39d1b3SJooyung Han // You may obtain a copy of the License at 6*5f39d1b3SJooyung Han // 7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0 8*5f39d1b3SJooyung Han // 9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software 10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS, 11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and 13*5f39d1b3SJooyung Han // limitations under the License. 14*5f39d1b3SJooyung Han 15*5f39d1b3SJooyung Han // kernel_reference.h: a reference kernel for CPU architectures where we don't 16*5f39d1b3SJooyung Han // have optimized kernels yet. Also useful for testing, as it's templatized 17*5f39d1b3SJooyung Han // to have any arbitrary format, allowing tests to cover all sorts of corner 18*5f39d1b3SJooyung Han // cases. 19*5f39d1b3SJooyung Han 20*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 21*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 22*5f39d1b3SJooyung Han 23*5f39d1b3SJooyung Han #include "kernel.h" 24*5f39d1b3SJooyung Han 25*5f39d1b3SJooyung Han #include <cstdio> 26*5f39d1b3SJooyung Han #include <cstring> 27*5f39d1b3SJooyung Han 28*5f39d1b3SJooyung Han namespace gemmlowp { 29*5f39d1b3SJooyung Han 30*5f39d1b3SJooyung Han // This kernel is templatized in an arbitrary Format template parameter, 31*5f39d1b3SJooyung Han // allowing it to have any arbitrary format. 32*5f39d1b3SJooyung Han template <typename tFormat> 33*5f39d1b3SJooyung Han struct ReferenceKernel : KernelBase { 34*5f39d1b3SJooyung Han typedef tFormat Format; 35*5f39d1b3SJooyung Han NameReferenceKernel36*5f39d1b3SJooyung Han const char* Name() const override { 37*5f39d1b3SJooyung Han static char buf[256]; 38*5f39d1b3SJooyung Han snprintf(buf, sizeof(buf), 39*5f39d1b3SJooyung Han "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)", 40*5f39d1b3SJooyung Han Format::Lhs::kCells, Format::Lhs::Cell::kWidth, 41*5f39d1b3SJooyung Han Format::Lhs::Cell::kDepth, 42*5f39d1b3SJooyung Han CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells, 43*5f39d1b3SJooyung Han Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth, 44*5f39d1b3SJooyung Han CellOrderName(Format::Rhs::Cell::kOrder)); 45*5f39d1b3SJooyung Han return buf; 46*5f39d1b3SJooyung Han } 47*5f39d1b3SJooyung Han RunReferenceKernel48*5f39d1b3SJooyung Han void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, 49*5f39d1b3SJooyung Han std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, 50*5f39d1b3SJooyung Han const std::uint8_t* rhs_ptr, std::size_t start_depth, 51*5f39d1b3SJooyung Han std::size_t run_depth) const override { 52*5f39d1b3SJooyung Han std::int32_t accumulator[Format::kRows * Format::kCols]; 53*5f39d1b3SJooyung Han memset(accumulator, 0, sizeof(accumulator)); 54*5f39d1b3SJooyung Han 55*5f39d1b3SJooyung Han const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth); 56*5f39d1b3SJooyung Han 57*5f39d1b3SJooyung Han // The outer loop is over the depth dimension. 58*5f39d1b3SJooyung Han for (int dc = 0; dc < run_depth_cells; dc++) { 59*5f39d1b3SJooyung Han // The next two loops are over cells of the Lhs (stacked vertically), 60*5f39d1b3SJooyung Han // and over cells of the Rhs (stacked horizontally). 61*5f39d1b3SJooyung Han for (int rc = 0; rc < Format::Lhs::kCells; rc++) { 62*5f39d1b3SJooyung Han const std::uint8_t* lhs_cell_ptr = 63*5f39d1b3SJooyung Han lhs_ptr + (dc * Format::Lhs::kCells + rc) * 64*5f39d1b3SJooyung Han Format::Lhs::Cell::kWidth * Format::kDepth; 65*5f39d1b3SJooyung Han for (int cc = 0; cc < Format::Rhs::kCells; cc++) { 66*5f39d1b3SJooyung Han const std::uint8_t* rhs_cell_ptr = 67*5f39d1b3SJooyung Han rhs_ptr + (dc * Format::Rhs::kCells + cc) * 68*5f39d1b3SJooyung Han Format::Rhs::Cell::kWidth * Format::kDepth; 69*5f39d1b3SJooyung Han 70*5f39d1b3SJooyung Han // Now we are inside one cell of the Lhs and inside one cell 71*5f39d1b3SJooyung Han // of the Rhs, so the remaining inner loops are just 72*5f39d1b3SJooyung Han // traditional three loops of matrix multiplication. 73*5f39d1b3SJooyung Han for (int di = 0; di < Format::kDepth; di++) { 74*5f39d1b3SJooyung Han for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) { 75*5f39d1b3SJooyung Han for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) { 76*5f39d1b3SJooyung Han const std::uint8_t* lhs_coeff_ptr = 77*5f39d1b3SJooyung Han lhs_cell_ptr + 78*5f39d1b3SJooyung Han OffsetIntoCell<typename Format::Lhs::Cell>(ri, di); 79*5f39d1b3SJooyung Han const std::uint8_t* rhs_coeff_ptr = 80*5f39d1b3SJooyung Han rhs_cell_ptr + 81*5f39d1b3SJooyung Han OffsetIntoCell<typename Format::Rhs::Cell>(ci, di); 82*5f39d1b3SJooyung Han std::int32_t* accumulator_coeff_ptr = 83*5f39d1b3SJooyung Han accumulator + (ri + rc * Format::Lhs::Cell::kWidth) + 84*5f39d1b3SJooyung Han (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows; 85*5f39d1b3SJooyung Han *accumulator_coeff_ptr += 86*5f39d1b3SJooyung Han std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr); 87*5f39d1b3SJooyung Han } 88*5f39d1b3SJooyung Han } 89*5f39d1b3SJooyung Han } 90*5f39d1b3SJooyung Han } 91*5f39d1b3SJooyung Han } 92*5f39d1b3SJooyung Han } 93*5f39d1b3SJooyung Han 94*5f39d1b3SJooyung Han if (start_depth == 0) { 95*5f39d1b3SJooyung Han // start_depth == 0 means we haven't accumulated anything yet, so we need 96*5f39d1b3SJooyung Han // to overwrite the accumulator, as it hasn't been initialized to zero. 97*5f39d1b3SJooyung Han for (int r = 0; r < Format::kRows; r++) { 98*5f39d1b3SJooyung Han for (int c = 0; c < Format::kCols; c++) { 99*5f39d1b3SJooyung Han dst_ptr[r * dst_row_stride + c * dst_col_stride] = 100*5f39d1b3SJooyung Han accumulator[r + c * Format::kRows]; 101*5f39d1b3SJooyung Han } 102*5f39d1b3SJooyung Han } 103*5f39d1b3SJooyung Han } else { 104*5f39d1b3SJooyung Han // We have already accumulated stuff, so we need to continue accumulating 105*5f39d1b3SJooyung Han // instead of just overwriting. 106*5f39d1b3SJooyung Han for (int r = 0; r < Format::kRows; r++) { 107*5f39d1b3SJooyung Han for (int c = 0; c < Format::kCols; c++) { 108*5f39d1b3SJooyung Han dst_ptr[r * dst_row_stride + c * dst_col_stride] += 109*5f39d1b3SJooyung Han accumulator[r + c * Format::kRows]; 110*5f39d1b3SJooyung Han } 111*5f39d1b3SJooyung Han } 112*5f39d1b3SJooyung Han } 113*5f39d1b3SJooyung Han } 114*5f39d1b3SJooyung Han }; 115*5f39d1b3SJooyung Han 116*5f39d1b3SJooyung Han } // namespace gemmlowp 117*5f39d1b3SJooyung Han 118*5f39d1b3SJooyung Han #endif // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 119