1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <iostream>
18
19 #include "ruy/ruy.h"
20
ExampleMulFloat(ruy::Context * context)21 void ExampleMulFloat(ruy::Context *context) {
22 const float lhs_data[] = {1, 2, 3, 4};
23 const float rhs_data[] = {1, 2, 3, 4};
24 float dst_data[4];
25
26 ruy::Matrix<float> lhs;
27 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
28 lhs.set_data(lhs_data);
29 ruy::Matrix<float> rhs;
30 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
31 rhs.set_data(rhs_data);
32 ruy::Matrix<float> dst;
33 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
34 dst.set_data(dst_data);
35
36 ruy::MulParams<float, float> mul_params;
37 ruy::Mul(lhs, rhs, mul_params, context, &dst);
38
39 std::cout << "Example Mul, float:\n";
40 std::cout << "LHS:\n" << lhs;
41 std::cout << "RHS:\n" << rhs;
42 std::cout << "Result:\n" << dst << "\n";
43 }
44
ExampleMulFloatWithBiasAddAndClamp(ruy::Context * context)45 void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) {
46 const float lhs_data[] = {1, 2, 3, 4};
47 const float rhs_data[] = {1, 2, 3, 4};
48 const float bias_data[] = {1, 0};
49 float dst_data[4];
50
51 ruy::Matrix<float> lhs;
52 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
53 lhs.set_data(lhs_data);
54 ruy::Matrix<float> rhs;
55 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
56 rhs.set_data(rhs_data);
57 ruy::Matrix<float> dst;
58 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
59 dst.set_data(dst_data);
60
61 ruy::MulParams<float, float> mul_params;
62 mul_params.set_bias(bias_data);
63 mul_params.set_clamp_min(0);
64 mul_params.set_clamp_max(15);
65 ruy::Mul(lhs, rhs, mul_params, context, &dst);
66
67 std::cout << "Example Mul, float with bias addition and clamp:\n";
68 std::cout << "LHS:\n" << lhs;
69 std::cout << "RHS:\n" << rhs;
70 std::cout << "Result:\n" << dst << "\n";
71 }
72
ExampleMulUint8AsymmetricQuantized(ruy::Context * context)73 void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) {
74 const std::uint8_t lhs_data[] = {124, 125, 126, 127};
75 const std::uint8_t rhs_data[] = {129, 130, 131, 132};
76 std::uint8_t dst_data[4];
77
78 ruy::Matrix<std::uint8_t> lhs;
79 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
80 lhs.set_data(lhs_data);
81 lhs.set_zero_point(125);
82 ruy::Matrix<std::uint8_t> rhs;
83 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
84 rhs.set_data(rhs_data);
85 rhs.set_zero_point(132);
86 ruy::Matrix<std::uint8_t> dst;
87 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
88 dst.set_data(dst_data);
89 dst.set_zero_point(129);
90
91 ruy::MulParams<std::int32_t, std::uint8_t> mul_params;
92 mul_params.set_multiplier_fixedpoint(1 << 30);
93
94 mul_params.set_multiplier_exponent(0);
95 ruy::Mul(lhs, rhs, mul_params, context, &dst);
96
97 std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n";
98 std::cout << "LHS:\n" << lhs;
99 std::cout << "RHS:\n" << rhs;
100 std::cout << "Result:\n" << dst << "\n";
101 }
ExampleMulInt8PerChannelQuantized(ruy::Context * context)102 void ExampleMulInt8PerChannelQuantized(ruy::Context *context) {
103 const std::int8_t lhs_data[] = {1, 2, 3, 4};
104 const std::int8_t rhs_data[] = {1, 2, 3, 4};
105 const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
106 const int exponent_data[] = {1, -2};
107 std::int8_t dst_data[4];
108
109 ruy::Matrix<std::int8_t> lhs;
110 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
111 lhs.set_data(lhs_data);
112 ruy::Matrix<std::int8_t> rhs;
113 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
114 rhs.set_data(rhs_data);
115 ruy::Matrix<std::int8_t> dst;
116 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
117 dst.set_data(dst_data);
118
119 ruy::MulParams<std::int32_t, std::int8_t> mul_params;
120 mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
121 mul_params.set_multiplier_exponent_perchannel(exponent_data);
122 ruy::Mul(lhs, rhs, mul_params, context, &dst);
123
124 std::cout << "Example Mul, int8 quantized with per-channel multipliers\n";
125 std::cout << "LHS:\n" << lhs;
126 std::cout << "RHS:\n" << rhs;
127 std::cout << "Result:\n" << dst << "\n";
128 }
129
ExampleMulInt8GetRawAccumulators(ruy::Context * context)130 void ExampleMulInt8GetRawAccumulators(ruy::Context *context) {
131 const std::int8_t lhs_data[] = {1, 2, 3, 4};
132 const std::int8_t rhs_data[] = {1, 2, 3, 4};
133 std::int32_t dst_data[4];
134
135 ruy::Matrix<std::int8_t> lhs;
136 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
137 lhs.set_data(lhs_data);
138 ruy::Matrix<std::int8_t> rhs;
139 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
140 rhs.set_data(rhs_data);
141 ruy::Matrix<std::int32_t> dst;
142 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
143 dst.set_data(dst_data);
144
145 // When Dst is int32, mul_params is unused.
146 ruy::MulParams<std::int32_t, std::int32_t> mul_params;
147 ruy::Mul(lhs, rhs, mul_params, context, &dst);
148
149 std::cout << "Example Mul, returning raw int32 accumulators:\n";
150 std::cout << "LHS:\n" << lhs;
151 std::cout << "RHS:\n" << rhs;
152 std::cout << "Result:\n" << dst << "\n";
153 }
154
ExampleMulInt8TimesInt16PerChannelQuantized(ruy::Context * context)155 void ExampleMulInt8TimesInt16PerChannelQuantized(ruy::Context *context) {
156 const std::int8_t lhs_data[] = {1, 2, 3, 4};
157 const std::int16_t rhs_data[] = {1000, 2000, 3000, 4000};
158 const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
159 const int exponent_data[] = {1, -2};
160 std::int16_t dst_data[4];
161
162 ruy::Matrix<std::int8_t> lhs;
163 ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
164 lhs.set_data(lhs_data);
165 ruy::Matrix<std::int16_t> rhs;
166 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
167 rhs.set_data(rhs_data);
168 ruy::Matrix<std::int16_t> dst;
169 ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
170 dst.set_data(dst_data);
171
172 ruy::MulParams<std::int32_t, std::int16_t> mul_params;
173 mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
174 mul_params.set_multiplier_exponent_perchannel(exponent_data);
175 ruy::Mul(lhs, rhs, mul_params, context, &dst);
176
177 std::cout << "Example Mul, int8 times int16 quantized with per-channel "
178 "multipliers\n";
179 std::cout << "LHS:\n" << lhs;
180 std::cout << "RHS:\n" << rhs;
181 std::cout << "Result:\n" << dst << "\n";
182 }
183
main()184 int main() {
185 ruy::Context context;
186 ExampleMulFloat(&context);
187 ExampleMulFloatWithBiasAddAndClamp(&context);
188 ExampleMulUint8AsymmetricQuantized(&context);
189 ExampleMulInt8PerChannelQuantized(&context);
190 ExampleMulInt8GetRawAccumulators(&context);
191 ExampleMulInt8TimesInt16PerChannelQuantized(&context);
192 }
193