xref: /aosp_15_r20/external/gemmlowp/internal/dispatch_gemm_shape.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
16 
17 #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
18 #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
19 
20 #include "../internal/kernel_default.h"
21 #include "../public/map.h"
22 #include "../public/output_stages.h"
23 #include "multi_thread_gemm.h"
24 
25 namespace gemmlowp {
26 
// Primary TransposeImpl template, used for any type that has no dedicated
// specialization. Transposition is the identity here: DstType is T itself
// and Run() hands back a copy of its argument.
template <typename T>
struct TransposeImpl {
  using DstType = T;
  static DstType Run(const T& value) { return value; }
};
32 
33 template <typename T>
34 using TransposeType = typename TransposeImpl<T>::DstType;
35 
36 template <typename T>
Transpose(const T & t)37 TransposeType<T> Transpose(const T& t) {
38   return TransposeImpl<T>::Run(t);
39 }
40 
41 template <MapOrder Order>
42 struct TransposeMapOrder {
43   static constexpr MapOrder Value =
44       Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
45 };
46 
47 template <VectorShape Shape>
48 struct TransposeVectorShape {
49   static constexpr VectorShape Value =
50       Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
51 };
52 
53 template <typename Scalar, VectorShape Shape>
54 struct TransposeImpl<VectorMap<Scalar, Shape>> {
55   typedef VectorMap<Scalar, Shape> SrcType;
56   static constexpr VectorShape TransposedShape =
57       TransposeVectorShape<Shape>::Value;
58   typedef VectorMap<Scalar, TransposedShape> DstType;
59   static DstType Run(const SrcType& src) {
60     return DstType(src.data(), src.size());
61   }
62 };
63 
64 template <typename Scalar, MapOrder Order>
65 struct TransposeImpl<MatrixMap<Scalar, Order>> {
66   typedef MatrixMap<Scalar, Order> SrcType;
67   static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
68   typedef MatrixMap<Scalar, TransposedOrder> DstType;
69   static DstType Run(const SrcType& src) {
70     return DstType(src.data(), src.cols(), src.rows(), src.stride());
71   }
72 };
73 
74 template <VectorShape Shape>
75 struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
76   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
77   static constexpr VectorShape TransposedShape =
78       TransposeVectorShape<Shape>::Value;
79   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
80   static DstType Run(const SrcType& src) {
81     DstType dst;
82     dst.result_shift = src.result_shift;
83     dst.result_offset = Transpose(src.result_offset);
84     dst.result_mult_int = Transpose(src.result_mult_int);
85     return dst;
86   }
87 };
88 
89 template <VectorShape Shape>
90 struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
91   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
92   static constexpr VectorShape TransposedShape =
93       TransposeVectorShape<Shape>::Value;
94   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
95       DstType;
96   static DstType Run(const SrcType& src) {
97     DstType dst;
98     dst.result_fixedpoint_multiplier =
99         Transpose(src.result_fixedpoint_multiplier);
100     dst.result_exponent = Transpose(src.result_exponent);
101     dst.result_offset_after_shift = src.result_offset_after_shift;
102     return dst;
103   }
104 };
105 
106 template <typename VectorMapType>
107 struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
108   typedef OutputStageBiasAddition<VectorMapType> SrcType;
109   typedef TransposeType<VectorMapType> TransposedVectorMapType;
110   typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
111   static DstType Run(const SrcType& src) {
112     DstType dst;
113     dst.bias_vector = Transpose(src.bias_vector);
114     return dst;
115   }
116 };
117 
118 // TODO(benoitjacob) - does anyone understand C++ variadic templates?
119 // How to use them to implement TransposeTuple? Note: there are lots
120 // of answers on StackOverflow but they seem to all involve either
121 // C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
122 inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
123 
124 template <typename T0>
125 std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
126   return std::make_tuple(Transpose(std::get<0>(t)));
127 }
128 
129 template <typename T0, typename T1>
130 std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
131     const std::tuple<T0, T1>& t) {
132   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
133 }
134 
135 template <typename T0, typename T1, typename T2>
136 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
137 TransposeTuple(const std::tuple<T0, T1, T2>& t) {
138   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
139                          Transpose(std::get<2>(t)));
140 }
141 
142 template <typename T0, typename T1, typename T2, typename T3>
143 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
144            TransposeType<T3>>
145 TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
146   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
147                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
148 }
149 
150 template <typename T0, typename T1, typename T2, typename T3, typename T4>
151 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
152            TransposeType<T3>, TransposeType<T4>>
153 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
154   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
155                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
156                          Transpose(std::get<4>(t)));
157 }
158 
159 template <typename T0, typename T1, typename T2, typename T3, typename T4,
160           typename T5>
161 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
162            TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
163 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
164   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
165                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
166                          Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
167 }
168 
169 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
170           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
171           typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
172           typename GemmContextType>
173 void DispatchGemmShape(GemmContextType* context,
174                        const MatrixMap<const InputScalar, LhsOrder>& lhs,
175                        const MatrixMap<const InputScalar, RhsOrder>& rhs,
176                        MatrixMap<OutputScalar, ResultOrder>* result,
177                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
178                        const OutputPipelineType& output_pipeline) {
179   assert(lhs.cols() == rhs.rows());
180 
181   int rows = result->rows();
182   int cols = result->cols();
183   int depth = lhs.cols();
184 
185   if (rows == 0 || cols == 0 || depth == 0) {
186     // Vacuous GEMM, return early to avoid having to deal with
187     // zero sizes below.
188     return;
189   }
190 
191   if (rows < cols) {
192     auto transposed_result_map = Transpose(*result);
193     return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
194         context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
195         Transpose(rhs_offset), Transpose(lhs_offset),
196         TransposeTuple(output_pipeline));
197   }
198 
199   typedef DefaultKernel<BitDepthParams> Kernel;
200   MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
201                   BitDepthParams>(context, Kernel(), lhs, rhs, result,
202                                   lhs_offset, rhs_offset, output_pipeline);
203 }
204 
205 }  // end namespace gemmlowp
206 
207 #endif  // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
208