1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han
15*5f39d1b3SJooyung Han // kernel.h: general definitions for kernels.
16*5f39d1b3SJooyung Han
17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_KERNEL_H_
18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_KERNEL_H_
19*5f39d1b3SJooyung Han
20*5f39d1b3SJooyung Han #include "../public/bit_depth.h"
21*5f39d1b3SJooyung Han #include "common.h"
22*5f39d1b3SJooyung Han
23*5f39d1b3SJooyung Han namespace gemmlowp {
24*5f39d1b3SJooyung Han
25*5f39d1b3SJooyung Han // Explanation of general gemmlowp terminology
26*5f39d1b3SJooyung Han // ===========================================
27*5f39d1b3SJooyung Han //
28*5f39d1b3SJooyung Han // We use the following abbreviations:
29*5f39d1b3SJooyung Han // LHS = "left-hand side"
30*5f39d1b3SJooyung Han // RHS = "right-hand side"
31*5f39d1b3SJooyung Han // Sometimes when referring to either LHS or RHS, we just say a "Side".
32*5f39d1b3SJooyung Han //
33*5f39d1b3SJooyung Han // In a matrix product of a MxK matrix times a KxN matrix,
34*5f39d1b3SJooyung Han // we call K the 'depth'. Note that M is the number of rows
35*5f39d1b3SJooyung Han // of the result (and of the LHS), and N is the number of columns
36*5f39d1b3SJooyung Han // of the result (and of the RHS).
37*5f39d1b3SJooyung Han //
38*5f39d1b3SJooyung Han // In each of the LHS and RHS matrices, we call 'width' the
39*5f39d1b3SJooyung Han // other dimension, besides the depth. So in the LHS, 'width'
40*5f39d1b3SJooyung Han // is the number of rows, while in the RHS, 'width' is the number
41*5f39d1b3SJooyung Han // of columns.
42*5f39d1b3SJooyung Han //
43*5f39d1b3SJooyung Han // So in the LHS MxK matrix, the depth is K and the width in M.
44*5f39d1b3SJooyung Han // And in the RHS KxN matrix, the depth is K and the width in N.
45*5f39d1b3SJooyung Han //
46*5f39d1b3SJooyung Han // This is illustrated in this picture:
47*5f39d1b3SJooyung Han //
48*5f39d1b3SJooyung Han // RHS width
49*5f39d1b3SJooyung Han // <----------------->
50*5f39d1b3SJooyung Han // +-----------------+ ^
51*5f39d1b3SJooyung Han // | RHS | | Depth
52*5f39d1b3SJooyung Han // +-----------------+ v
53*5f39d1b3SJooyung Han // ^ +--+ +-----------------+
54*5f39d1b3SJooyung Han // | |L | | |
55*5f39d1b3SJooyung Han // LHS width | |H | | Result |
56*5f39d1b3SJooyung Han // | |S | | |
57*5f39d1b3SJooyung Han // v +--+ +-----------------+
58*5f39d1b3SJooyung Han // <-->
59*5f39d1b3SJooyung Han // Depth
60*5f39d1b3SJooyung Han
61*5f39d1b3SJooyung Han // Explanation of gemmlowp kernel formats and "cells"
62*5f39d1b3SJooyung Han // ==================================================
63*5f39d1b3SJooyung Han //
64*5f39d1b3SJooyung Han // Kernels operate on small LHS and RHS blocks that fit in registers.
65*5f39d1b3SJooyung Han // These blocks are stored contiguously in memory, but not always
66*5f39d1b3SJooyung Han // in a traditional column-major or row-major order; instead,
67*5f39d1b3SJooyung Han // they consist of a number of sub-blocks, which we call "cells",
68*5f39d1b3SJooyung Han // that are stored in column-major or row-major order. However,
69*5f39d1b3SJooyung Han // what really matters to us is not so much rows vs columns, but
70*5f39d1b3SJooyung Han // rather width vs depth. So we refer to "width-major" and "depth-major"
71*5f39d1b3SJooyung Han // storage orders. In the LHS, width-major means row-major,
72*5f39d1b3SJooyung Han // while in the RHS, width-major means column-major.
73*5f39d1b3SJooyung Han // There is also a third possibility, "diagonal order",
74*5f39d1b3SJooyung Han // which is unused at the moment.
75*5f39d1b3SJooyung Han //
76*5f39d1b3SJooyung Han // We aim to treat both sides, LHS and RHS, on an equal footing,
77*5f39d1b3SJooyung Han // so we call them both 'sides'. A KernelFormat thus is just a pair
78*5f39d1b3SJooyung Han // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
79*5f39d1b3SJooyung Han // contains a CellFormat and a number of cells; cells are only ever
80*5f39d1b3SJooyung Han // stacked in the width dimension, which means stacked vertically in the
81*5f39d1b3SJooyung Han // LHS and stacked horizondally in the RHS.
82*5f39d1b3SJooyung Han //
83*5f39d1b3SJooyung Han // Example
84*5f39d1b3SJooyung Han // =======
85*5f39d1b3SJooyung Han //
86*5f39d1b3SJooyung Han // Let's work out the data layout expected by a kernel having the
87*5f39d1b3SJooyung Han // following format (the struct names here are defined below in this file):
88*5f39d1b3SJooyung Han //
89*5f39d1b3SJooyung Han // KernelFormat<
90*5f39d1b3SJooyung Han // KernelSideFormat<CellFormat<3, 4>, 3>,
91*5f39d1b3SJooyung Han // KernelSideFormat<CellFormat<5, 4>, 2>
92*5f39d1b3SJooyung Han // >
93*5f39d1b3SJooyung Han //
94*5f39d1b3SJooyung Han // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
95*5f39d1b3SJooyung Han // 3 cells, each cell having dimensions (width=3, depth=4), laid out in
96*5f39d1b3SJooyung Han // DepthMajor order (the default value, see CellFormat). In the LHS,
97*5f39d1b3SJooyung Han // DepthMajor means column-major, so the LHS cells are of size 3x4 in
98*5f39d1b3SJooyung Han // column-major order, so the LHS layout is:
99*5f39d1b3SJooyung Han //
100*5f39d1b3SJooyung Han // 0 3 6 9
101*5f39d1b3SJooyung Han // 1 4 7 10
102*5f39d1b3SJooyung Han // 2 5 8 11
103*5f39d1b3SJooyung Han // 12 15 18 21
104*5f39d1b3SJooyung Han // 13 16 19 22
105*5f39d1b3SJooyung Han // 14 17 20 23
106*5f39d1b3SJooyung Han // 24 27 30 33
107*5f39d1b3SJooyung Han // 25 28 31 34
108*5f39d1b3SJooyung Han // 26 29 32 35
109*5f39d1b3SJooyung Han //
110*5f39d1b3SJooyung Han // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
111*5f39d1b3SJooyung Han // 2 cells each having dimensions (width=5, depth=4), laid out in
112*5f39d1b3SJooyung Han // DepthMajor order (the default value, see CellFormat). In the RHS,
113*5f39d1b3SJooyung Han // DepthMajor means row-major, so the RHS cells are of size 4x5 in
114*5f39d1b3SJooyung Han // row-major order, so the RHS layout is:
115*5f39d1b3SJooyung Han //
116*5f39d1b3SJooyung Han // 0 1 2 3 4 20 21 22 23 24
117*5f39d1b3SJooyung Han // 5 6 7 8 9 25 26 27 28 29
118*5f39d1b3SJooyung Han // 10 11 12 13 14 30 31 32 33 34
119*5f39d1b3SJooyung Han // 15 16 17 18 19 35 36 37 38 39
120*5f39d1b3SJooyung Han
121*5f39d1b3SJooyung Han // CellOrder enumerates the possible storage orders (=layouts) for
122*5f39d1b3SJooyung Han // a cell (see explanation above).
123*5f39d1b3SJooyung Han enum class CellOrder { DepthMajor, WidthMajor, Diagonal };
124*5f39d1b3SJooyung Han
125*5f39d1b3SJooyung Han // CellFormat describes how data is laid
126*5f39d1b3SJooyung Han // out in a cell. That is, a CellOrder together with actual dimensions.
127*5f39d1b3SJooyung Han template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
128*5f39d1b3SJooyung Han struct CellFormat {
129*5f39d1b3SJooyung Han static constexpr int kWidth = tWidth;
130*5f39d1b3SJooyung Han static constexpr int kDepth = tDepth;
131*5f39d1b3SJooyung Han static constexpr CellOrder kOrder = tOrder;
132*5f39d1b3SJooyung Han
133*5f39d1b3SJooyung Han static constexpr int kSize = kWidth * kDepth;
134*5f39d1b3SJooyung Han };
135*5f39d1b3SJooyung Han
136*5f39d1b3SJooyung Han // KernelSideFormat describes how data is laid out in a kernel side
137*5f39d1b3SJooyung Han // (i.e. LHS or RHS). That is, a CellFormat together with a number of
138*5f39d1b3SJooyung Han // cells. These cells are always stacked in the Width dimension.
139*5f39d1b3SJooyung Han // For example, in the LHS case, the Width dimension is the rows dimension,
140*5f39d1b3SJooyung Han // se we're saying that in the LHS, cells are stacked vertically.
141*5f39d1b3SJooyung Han // We never stack cells in the Depth dimension.
142*5f39d1b3SJooyung Han template <typename tCellFormat, int tCells>
143*5f39d1b3SJooyung Han struct KernelSideFormat {
144*5f39d1b3SJooyung Han typedef tCellFormat Cell;
145*5f39d1b3SJooyung Han static constexpr int kCells = tCells;
146*5f39d1b3SJooyung Han static constexpr int kWidth = kCells * Cell::kWidth;
147*5f39d1b3SJooyung Han static constexpr int kDepth = Cell::kDepth;
148*5f39d1b3SJooyung Han typedef std::uint8_t Scalar; // The scalar type of the Format.
149*5f39d1b3SJooyung Han typedef std::uint8_t InputScalar; // The scalar type of the original input.
150*5f39d1b3SJooyung Han };
151*5f39d1b3SJooyung Han
152*5f39d1b3SJooyung Han // KernelSideFormat for int8 fast kernel trick. The original input is uint8, but
153*5f39d1b3SJooyung Han // packs converts it to int8.
154*5f39d1b3SJooyung Han template <typename tCellFormat, int tCells>
155*5f39d1b3SJooyung Han struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
156*5f39d1b3SJooyung Han typedef std::int8_t Scalar;
157*5f39d1b3SJooyung Han typedef std::uint8_t InputScalar;
158*5f39d1b3SJooyung Han };
159*5f39d1b3SJooyung Han
160*5f39d1b3SJooyung Han // KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
161*5f39d1b3SJooyung Han // pack conversion.
162*5f39d1b3SJooyung Han template <typename tCellFormat, int tCells>
163*5f39d1b3SJooyung Han struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
164*5f39d1b3SJooyung Han typedef std::int8_t Scalar;
165*5f39d1b3SJooyung Han typedef std::int8_t InputScalar;
166*5f39d1b3SJooyung Han };
167*5f39d1b3SJooyung Han
168*5f39d1b3SJooyung Han // KernelFormat describes fully the input data layout that a kernel expects.
169*5f39d1b3SJooyung Han // It consists of two KernelSideFormat's, one for LHS and one for RHS.
170*5f39d1b3SJooyung Han template <typename tLhs, typename tRhs>
171*5f39d1b3SJooyung Han struct KernelFormat {
172*5f39d1b3SJooyung Han typedef tLhs Lhs;
173*5f39d1b3SJooyung Han typedef tRhs Rhs;
174*5f39d1b3SJooyung Han
175*5f39d1b3SJooyung Han static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, "");
176*5f39d1b3SJooyung Han static constexpr int kDepth = Lhs::Cell::kDepth;
177*5f39d1b3SJooyung Han static constexpr int kRows = Lhs::Cell::kWidth * Lhs::kCells;
178*5f39d1b3SJooyung Han static constexpr int kCols = Rhs::Cell::kWidth * Rhs::kCells;
179*5f39d1b3SJooyung Han };
180*5f39d1b3SJooyung Han
CellOrderName(CellOrder o)181*5f39d1b3SJooyung Han inline const char* CellOrderName(CellOrder o) {
182*5f39d1b3SJooyung Han switch (o) {
183*5f39d1b3SJooyung Han case CellOrder::DepthMajor:
184*5f39d1b3SJooyung Han return "DepthMajor";
185*5f39d1b3SJooyung Han case CellOrder::WidthMajor:
186*5f39d1b3SJooyung Han return "WidthMajor";
187*5f39d1b3SJooyung Han case CellOrder::Diagonal:
188*5f39d1b3SJooyung Han return "Diagonal";
189*5f39d1b3SJooyung Han default:
190*5f39d1b3SJooyung Han assert(false);
191*5f39d1b3SJooyung Han return nullptr;
192*5f39d1b3SJooyung Han }
193*5f39d1b3SJooyung Han }
194*5f39d1b3SJooyung Han
195*5f39d1b3SJooyung Han // Returns the offset into a cell, at which a given coefficient is stored.
196*5f39d1b3SJooyung Han template <typename CellFormat>
OffsetIntoCell(int w,int d)197*5f39d1b3SJooyung Han inline int OffsetIntoCell(int w, int d) {
198*5f39d1b3SJooyung Han const int size = CellFormat::kWidth;
199*5f39d1b3SJooyung Han switch (CellFormat::kOrder) {
200*5f39d1b3SJooyung Han case CellOrder::DepthMajor:
201*5f39d1b3SJooyung Han return w + d * CellFormat::kWidth;
202*5f39d1b3SJooyung Han case CellOrder::WidthMajor:
203*5f39d1b3SJooyung Han return d + w * CellFormat::kDepth;
204*5f39d1b3SJooyung Han case CellOrder::Diagonal:
205*5f39d1b3SJooyung Han assert(CellFormat::kWidth == CellFormat::kDepth);
206*5f39d1b3SJooyung Han return ((size + w - d) * size + d) % (size * size);
207*5f39d1b3SJooyung Han default:
208*5f39d1b3SJooyung Han assert(false);
209*5f39d1b3SJooyung Han return 0;
210*5f39d1b3SJooyung Han }
211*5f39d1b3SJooyung Han }
212*5f39d1b3SJooyung Han
213*5f39d1b3SJooyung Han // KernelBase is the virtual base class below all kernels.
214*5f39d1b3SJooyung Han // The idea is that we don't need to templatize all our code on the exact
215*5f39d1b3SJooyung Han // kernel type; we only need to templatize on kernel format. Kernels
216*5f39d1b3SJooyung Han // sharing the same format can thus share the same packing/unpacking code.
217*5f39d1b3SJooyung Han struct KernelBase {
218*5f39d1b3SJooyung Han virtual const char* Name() const = 0;
219*5f39d1b3SJooyung Han
220*5f39d1b3SJooyung Han // This is the kernel implementation. We use the word 'run' consistently
221*5f39d1b3SJooyung Han // throughout gemmlowp to mean an inner loop, the implementation of which
222*5f39d1b3SJooyung Han // is to be provided by a separate optimized function.
223*5f39d1b3SJooyung Han virtual void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
224*5f39d1b3SJooyung Han std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
225*5f39d1b3SJooyung Han const std::uint8_t* rhs_ptr, std::size_t start_depth,
226*5f39d1b3SJooyung Han std::size_t run_depth) const = 0;
227*5f39d1b3SJooyung Han
~KernelBaseKernelBase228*5f39d1b3SJooyung Han virtual ~KernelBase() {}
229*5f39d1b3SJooyung Han };
230*5f39d1b3SJooyung Han
231*5f39d1b3SJooyung Han template <typename InputKernelScalarType, typename KernelScalarType>
232*5f39d1b3SJooyung Han struct ZeroPointInputValue {};
233*5f39d1b3SJooyung Han
234*5f39d1b3SJooyung Han template <>
235*5f39d1b3SJooyung Han struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
236*5f39d1b3SJooyung Han static constexpr std::uint8_t kValue = 0;
237*5f39d1b3SJooyung Han };
238*5f39d1b3SJooyung Han
239*5f39d1b3SJooyung Han template <>
240*5f39d1b3SJooyung Han struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
241*5f39d1b3SJooyung Han static constexpr std::uint8_t kValue = 128;
242*5f39d1b3SJooyung Han };
243*5f39d1b3SJooyung Han
244*5f39d1b3SJooyung Han template <>
245*5f39d1b3SJooyung Han struct ZeroPointInputValue<std::int8_t, std::int8_t> {
246*5f39d1b3SJooyung Han static constexpr std::uint8_t kValue = 0;
247*5f39d1b3SJooyung Han };
248*5f39d1b3SJooyung Han
249*5f39d1b3SJooyung Han } // namespace gemmlowp
250*5f39d1b3SJooyung Han
251*5f39d1b3SJooyung Han #endif // GEMMLOWP_INTERNAL_KERNEL_H_
252