// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// single_thread_gemm.h: Single-threaded GEMM implementation.
// This is a good place to start reading code, as it shows the overall
// structure of a GEMM and is much simpler than multi_thread_gemm.h.
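//
// Note: callers normally do not invoke SingleThreadGemm directly; they go
// through the public entry points in public/gemmlowp.h, which dispatch to
// the single-threaded or multi-threaded implementation. A rough usage
// sketch of the public interface (see public/gemmlowp.h for the
// authoritative signatures; lhs_offset, rhs_offset and output_pipeline are
// assumed to be defined by the caller):
//
//   gemmlowp::GemmContext context;
//   gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::RowMajor>
//       lhs(lhs_data, rows, depth);
//   gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::ColMajor>
//       rhs(rhs_data, depth, cols);
//   gemmlowp::MatrixMap<std::uint8_t, gemmlowp::MapOrder::ColMajor>
//       result(result_data, rows, cols);
//   gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
//                                    gemmlowp::DefaultL8R8BitDepthParams>(
//       &context, lhs, rhs, &result, lhs_offset, rhs_offset,
//       output_pipeline);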

#ifndef GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_

#include <cassert>

#include "../public/map.h"
#include "allocator.h"
#include "compute.h"
#include "kernel.h"
#include "pack.h"
#include "unpack.h"

#ifdef GEMMLOWP_PROFILING_SIZES
#ifndef GEMMLOWP_PROFILING
#error GEMMLOWP_PROFILING_SIZES without GEMMLOWP_PROFILING
#endif
#include <string>
#include <unordered_map>
#endif

namespace gemmlowp {

class SingleThreadGemmContext {
 public:
  Allocator* allocator() { return &allocator_; }

  void set_l1_bytes_to_use(int n) { l1_bytes_to_use_ = n; }
  void set_l2_bytes_to_use(int n) { l2_bytes_to_use_ = n; }
  void set_l2_rhs_factor(float n) { l2_rhs_factor_ = n; }

  int l1_bytes_to_use() const { return l1_bytes_to_use_; }
  int l2_bytes_to_use() const { return l2_bytes_to_use_; }
  float l2_rhs_factor() const { return l2_rhs_factor_; }

 protected:
  Allocator allocator_;
  // The cache configuration to use.
  int l1_bytes_to_use_ = kDefaultL1CacheSize;
  int l2_bytes_to_use_ = kDefaultL2CacheSize;
  float l2_rhs_factor_ = kDefaultL2RhsFactor;
};

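// SingleThreadGemm performs the whole GEMM in three stages, operating on
// cache-sized blocks: the LHS and RHS blocks are packed into the layout
// expected by the kernel, the compute kernel is run on the packed blocks,
// and the resulting accumulators are unpacked into the destination matrix
// while the offsets and the output pipeline are applied.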
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
void SingleThreadGemm(SingleThreadGemmContext* context,
                      const KernelBase& kernel,
                      const MatrixMap<const InputScalar, LhsOrder>& lhs,
                      const MatrixMap<const InputScalar, RhsOrder>& rhs,
                      MatrixMap<OutputScalar, ResultOrder>* result,
                      const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                      const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label("gemmlowp::SingleThreadGemm");

  assert(lhs.cols() == rhs.rows());

  int rows = result->rows();
  int cols = result->cols();
  int depth = lhs.cols();

  // Zero sizes should have been caught earlier and early-returned.
  assert(rows > 0);
  assert(cols > 0);
  assert(depth > 0);

  // The case of rows < cols should have been caught earlier and transposed.
  assert(rows >= cols);

  Allocator* allocator = context->allocator();

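  // Compute the two-level (l1/l2) blocking sizes to use for this GEMM shape,
  // given the kernel's block format and the configured cache sizes.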
  BlockParams block_params;
  block_params.Init<KernelFormat>(
      rows, cols, depth, 1, context->l1_bytes_to_use(),
      context->l2_bytes_to_use(), context->l2_rhs_factor());

#ifdef GEMMLOWP_PROFILING_SIZES
  // Using a static map of label strings. Not reentrant at all!
  static std::unordered_map<std::uint64_t, std::string> labels_map;
  std::uint64_t sizes_hash = static_cast<std::uint64_t>(rows) ^
                             (static_cast<std::uint64_t>(depth) << 16) ^
                             (static_cast<std::uint64_t>(cols) << 32);
  if (!labels_map.count(sizes_hash)) {
    char label[256];
    snprintf(label, sizeof(label),
             "(rows = %d, depth = %d, cols = %d, l2_rows = %d, l2_depth = %d, "
             "l2_cols = %d, l1_rows = %d, l1_depth = %d, l1_cols = %d)",
             rows, depth, cols, block_params.l2_rows, block_params.l2_depth,
             block_params.l2_cols, block_params.l1_rows, block_params.l1_depth,
             block_params.l1_cols);
    labels_map[sizes_hash] = label;
  }
  ScopedProfilingLabel size_label(labels_map[sizes_hash].c_str());
#endif

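  // Packed working blocks for the LHS, RHS and accumulators, laid out in the
  // format required by the kernel. They reserve space on the context's
  // Allocator; the actual memory is obtained by the Commit() call below.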
  PackedSideBlock<typename KernelFormat::Lhs> packed_lhs(Side::Lhs, allocator,
                                                         block_params);
  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator,
                                                         block_params);

  PackedResult packed_result(allocator, block_params);

  allocator->Commit();

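  // If a single l2 block spans all columns, the RHS only needs to be packed
  // once, outside of the loops below.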
  const bool pack_rhs_once = block_params.l2_cols >= cols;

  if (pack_rhs_once) {
    PackRhs(&packed_rhs, rhs);
  }

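  // Outer loop over l2-sized row blocks: each LHS block is packed once and
  // then reused against every l2-sized column block of the RHS.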
  for (int r = 0; r < rows; r += block_params.l2_rows) {
    int rs = std::min(block_params.l2_rows, rows - r);

    PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth));

    for (int c = 0; c < cols; c += block_params.l2_cols) {
      int cs = std::min(block_params.l2_cols, cols - c);

      if (!pack_rhs_once) {
        PackRhs(&packed_rhs, rhs.block(0, c, depth, cs));
      }

      Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs,
              depth);

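      // Unpack the accumulators for this block into the destination matrix,
      // adding the lhs_offset/rhs_offset contributions (computed from the
      // sums of each packed slice) and applying the output pipeline.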
      UnpackResult<KernelFormat>(
          result, MatrixBlockBounds(r, c, rs, cs), packed_result, depth,
          packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(),
          lhs_offset.block(r, rs), rhs_offset.block(c, cs), output_pipeline);
    }
  }

  allocator->Decommit();
}

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_