1 /*
2 * Copyright (c) 2018-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/core/NEON/kernels/NEFuseBatchNormalizationKernel.h"
25 #include "src/cpu/kernels/fuse_batch_normalization/list.h"
26
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Utils.h"
31 #include "arm_compute/core/Validate.h"
32 #include "arm_compute/core/Window.h"
33 #include "src/common/cpuinfo/CpuIsaInfo.h"
34 #include "src/core/CPP/Validate.h"
35 #include "src/core/NEON/wrapper/wrapper.h"
36 #include "src/core/common/Registrars.h"
37 #include "src/core/helpers/AutoConfiguration.h"
38 #include "src/core/helpers/WindowHelpers.h"
39
40 #include <map>
41
42 namespace arm_compute
43 {
44 namespace
45 {
46 struct FuseBatchNormalizeSelectorData
47 {
48 DataType dt;
49 DataLayout dl;
50 FuseBatchNormalizationType fbn_type;
51 cpuinfo::CpuIsaInfo isa;
52 };
53
54 using FBNSelectorPtr = std::add_pointer<bool(const FuseBatchNormalizeSelectorData &data)>::type;
55 using FBNUKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, ITensor *,
56 const ITensor *, const ITensor *, const ITensor *, const ITensor *, float, const Window &)>::type;
57
58 struct FBNUKernel
59 {
60 const char *name;
61 const FBNSelectorPtr is_selected;
62 FBNUKernelPtr ukernel;
63 };
64
65 static const FBNUKernel available_kernels[] =
66 {
67 {
68 "fused_batch_normalization_conv_NHWC_F16",
69 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0202() 70 {
71 return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
72 },
73 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)
74 },
75 {
76 "fused_batch_normalization_conv_NCHW_F16",
77 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0302() 78 {
79 return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
80 },
81 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)
82 },
83 {
84 "fused_batch_normalization_dwc_NHWC_F16",
85 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0402() 86 {
87 return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
88 },
89 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f16)
90 },
91 {
92 "fused_batch_normalization_dwc_NCHW_F16",
93 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0502() 94 {
95 return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
96 },
97 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f16)
98 },
99 {
100 "fused_batch_normalization_conv_NHWC_F32",
101 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0602() 102 {
103 return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
104 },
105 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)
106 },
107 {
108 "fused_batch_normalization_conv_NCHW_F32",
109 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0702() 110 {
111 return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
112 },
113 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)
114 },
115 {
116 "fused_batch_normalization_dwc_NHWC_F32",
117 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0802() 118 {
119 return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
120 },
121 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f32)
122 },
123 {
124 "fused_batch_normalization_dwc_NCHW_F32",
125 [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0902() 126 {
127 return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
128 },
129 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f32)
130 }
131 };
132
133 /** Micro-kernel selector
134 *
135 * @param[in] data Selection data passed to help pick the appropriate micro-kernel
136 *
137 * @param[in]
138 *
139 * @return A matching micro-kernel else nullptr
140 */
get_implementation(const FuseBatchNormalizeSelectorData & data)141 const FBNUKernel *get_implementation(const FuseBatchNormalizeSelectorData &data)
142 {
143 for(const auto &uk : available_kernels)
144 {
145 if(uk.is_selected(data))
146 {
147 return &uk;
148 }
149 }
150 return nullptr;
151 }
152
validate_arguments(const ITensorInfo * input_weights,const ITensorInfo * bn_mean,const ITensorInfo * bn_var,const ITensorInfo * fused_weights,const ITensorInfo * fused_bias,const ITensorInfo * input_bias,const ITensorInfo * bn_beta,const ITensorInfo * bn_gamma,float epsilon,FuseBatchNormalizationType fbn_type)153 Status validate_arguments(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
154 const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
155 const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
156 float epsilon, FuseBatchNormalizationType fbn_type)
157 {
158 ARM_COMPUTE_UNUSED(epsilon);
159 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
160 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input_weights);
161 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_weights, 1, DataType::F16, DataType::F32);
162 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_var);
163 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_mean, bn_var);
164 ARM_COMPUTE_RETURN_ERROR_ON(input_bias == nullptr && fused_bias == nullptr);
165 ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);
166
167 if(fbn_type == FuseBatchNormalizationType::CONVOLUTION)
168 {
169 ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(3) != bn_mean->dimension(0));
170 }
171 else
172 {
173 const size_t channel_idx = get_data_layout_dimension_index(input_weights->data_layout(), DataLayoutDimension::CHANNEL);
174 ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(channel_idx) != bn_mean->dimension(0));
175 }
176 // Validate bias
177 if(input_bias != nullptr)
178 {
179 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, input_bias);
180 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, input_bias);
181 }
182 // Validate beta
183 if(bn_beta != nullptr)
184 {
185 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_beta);
186 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_beta);
187 }
188 // Validate gamma
189 if(bn_gamma != nullptr)
190 {
191 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_gamma);
192 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_gamma);
193 }
194
195 // Validate output weights
196 if(fused_weights != nullptr && fused_weights->total_size() != 0)
197 {
198 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_weights, fused_weights);
199 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input_weights, fused_weights);
200 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_weights);
201 }
202 // Validate output bias
203 if(fused_bias != nullptr && fused_bias->total_size() != 0)
204 {
205 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, fused_bias);
206 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_bias);
207 }
208
209 return Status{};
210 }
211
212 } // namespace
213
NEFuseBatchNormalizationKernel()214 NEFuseBatchNormalizationKernel::NEFuseBatchNormalizationKernel()
215 : _input_weights(nullptr), _input_bias(nullptr), _bn_mean(nullptr), _bn_var(nullptr), _bn_gamma(nullptr), _bn_beta(nullptr), _fused_weights(nullptr), _fused_bias(nullptr), _epsilon(),
216 _run_in_place_weights(false), _run_in_place_bias(false), _func(nullptr)
217 {
218 }
219
configure(const ITensor * input_weights,const ITensor * bn_mean,const ITensor * bn_var,ITensor * fused_weights,ITensor * fused_bias,const ITensor * input_bias,const ITensor * bn_beta,const ITensor * bn_gamma,float epsilon,FuseBatchNormalizationType fbn_type)220 void NEFuseBatchNormalizationKernel::configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var,
221 ITensor *fused_weights, ITensor *fused_bias,
222 const ITensor *input_bias, const ITensor *bn_beta, const ITensor *bn_gamma,
223 float epsilon, FuseBatchNormalizationType fbn_type)
224 {
225 ARM_COMPUTE_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
226
227 _input_weights = input_weights;
228 _input_bias = input_bias;
229 _bn_mean = bn_mean;
230 _bn_var = bn_var;
231 _bn_beta = bn_beta;
232 _bn_gamma = bn_gamma;
233 _fused_weights = fused_weights;
234 _fused_bias = fused_bias;
235 _epsilon = epsilon;
236
237 _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == input_weights);
238 _run_in_place_bias = (fused_bias == nullptr) || (input_bias != nullptr && fused_bias == input_bias);
239
240 // Auto initialize outputs
241 if(_fused_weights != nullptr)
242 {
243 // Output tensor auto initialization if not yet initialized
244 auto_init_if_empty(*_fused_weights->info(), *_input_weights->info()->clone());
245 }
246 if(_fused_bias != nullptr)
247 {
248 // Output tensor auto initialization if not yet initialized
249 auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
250 }
251
252 // Validate arguments
253 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input_weights->info(), bn_mean->info(), bn_var->info(),
254 (fused_weights != nullptr) ? fused_weights->info() : nullptr,
255 (fused_bias != nullptr) ? fused_bias->info() : nullptr,
256 (input_bias != nullptr) ? input_bias->info() : nullptr,
257 (bn_beta != nullptr) ? bn_beta->info() : nullptr,
258 (bn_gamma != nullptr) ? bn_gamma->info() : nullptr,
259 epsilon, fbn_type));
260
261 const auto *uk = get_implementation(FuseBatchNormalizeSelectorData{ input_weights->info()->data_type(), input_weights->info()->data_layout(), fbn_type, CPUInfo::get().get_isa() });
262 ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
263 ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
264 _func = uk->ukernel;
265
266 // Configure kernel window
267 Window win = calculate_max_window(*input_weights->info());
268 INEKernel::configure(win);
269 }
270
validate(const ITensorInfo * input_weights,const ITensorInfo * bn_mean,const ITensorInfo * bn_var,const ITensorInfo * fused_weights,const ITensorInfo * fused_bias,const ITensorInfo * input_bias,const ITensorInfo * bn_beta,const ITensorInfo * bn_gamma,float epsilon,FuseBatchNormalizationType fbn_type)271 Status NEFuseBatchNormalizationKernel::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
272 const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
273 const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
274 float epsilon, FuseBatchNormalizationType fbn_type)
275 {
276 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type));
277 return Status{};
278 }
279
run(const Window & window,const ThreadInfo & info)280 void NEFuseBatchNormalizationKernel::run(const Window &window, const ThreadInfo &info)
281 {
282 ARM_COMPUTE_UNUSED(info);
283 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
284 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
285
286 ARM_COMPUTE_ERROR_ON(_func == nullptr);
287 (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon, window);
288 }
289 } // namespace arm_compute
290