xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/NEFuseBatchNormalizationKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEFuseBatchNormalizationKernel.h"
25 #include "src/cpu/kernels/fuse_batch_normalization/list.h"
26 
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Utils.h"
31 #include "arm_compute/core/Validate.h"
32 #include "arm_compute/core/Window.h"
33 #include "src/common/cpuinfo/CpuIsaInfo.h"
34 #include "src/core/CPP/Validate.h"
35 #include "src/core/NEON/wrapper/wrapper.h"
36 #include "src/core/common/Registrars.h"
37 #include "src/core/helpers/AutoConfiguration.h"
38 #include "src/core/helpers/WindowHelpers.h"
39 
40 #include <map>
41 
42 namespace arm_compute
43 {
44 namespace
45 {
46 struct FuseBatchNormalizeSelectorData
47 {
48     DataType                   dt;
49     DataLayout                 dl;
50     FuseBatchNormalizationType fbn_type;
51     cpuinfo::CpuIsaInfo        isa;
52 };
53 
54 using FBNSelectorPtr = std::add_pointer<bool(const FuseBatchNormalizeSelectorData &data)>::type;
55 using FBNUKernelPtr  = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, ITensor *,
56                                              const ITensor *, const ITensor *, const ITensor *, const ITensor *, float, const Window &)>::type;
57 
58 struct FBNUKernel
59 {
60     const char          *name;
61     const FBNSelectorPtr is_selected;
62     FBNUKernelPtr        ukernel;
63 };
64 
65 static const FBNUKernel available_kernels[] =
66 {
67     {
68         "fused_batch_normalization_conv_NHWC_F16",
69         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0202() 70         {
71             return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
72         },
73         REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)
74     },
75     {
76         "fused_batch_normalization_conv_NCHW_F16",
77         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0302() 78         {
79             return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
80         },
81         REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)
82     },
83     {
84         "fused_batch_normalization_dwc_NHWC_F16",
85         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0402() 86         {
87             return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
88         },
89         REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f16)
90     },
91     {
92         "fused_batch_normalization_dwc_NCHW_F16",
93         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0502() 94         {
95             return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
96         },
97         REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f16)
98     },
99     {
100         "fused_batch_normalization_conv_NHWC_F32",
101         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0602() 102         {
103             return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
104         },
105         REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)
106     },
107     {
108         "fused_batch_normalization_conv_NCHW_F32",
109         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0702() 110         {
111             return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
112         },
113         REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)
114     },
115     {
116         "fused_batch_normalization_dwc_NHWC_F32",
117         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0802() 118         {
119             return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
120         },
121         REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f32)
122     },
123     {
124         "fused_batch_normalization_dwc_NCHW_F32",
125         [](const FuseBatchNormalizeSelectorData & data)
__anon040f663a0902() 126         {
127             return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
128         },
129         REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f32)
130     }
131 };
132 
133 /** Micro-kernel selector
134  *
135  * @param[in] data Selection data passed to help pick the appropriate micro-kernel
136  *
137  * @param[in]
138  *
139  * @return A matching micro-kernel else nullptr
140  */
get_implementation(const FuseBatchNormalizeSelectorData & data)141 const FBNUKernel *get_implementation(const FuseBatchNormalizeSelectorData &data)
142 {
143     for(const auto &uk : available_kernels)
144     {
145         if(uk.is_selected(data))
146         {
147             return &uk;
148         }
149     }
150     return nullptr;
151 }
152 
validate_arguments(const ITensorInfo * input_weights,const ITensorInfo * bn_mean,const ITensorInfo * bn_var,const ITensorInfo * fused_weights,const ITensorInfo * fused_bias,const ITensorInfo * input_bias,const ITensorInfo * bn_beta,const ITensorInfo * bn_gamma,float epsilon,FuseBatchNormalizationType fbn_type)153 Status validate_arguments(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
154                           const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
155                           const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
156                           float epsilon, FuseBatchNormalizationType fbn_type)
157 {
158     ARM_COMPUTE_UNUSED(epsilon);
159     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
160     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input_weights);
161     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_weights, 1, DataType::F16, DataType::F32);
162     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_var);
163     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_mean, bn_var);
164     ARM_COMPUTE_RETURN_ERROR_ON(input_bias == nullptr && fused_bias == nullptr);
165     ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);
166 
167     if(fbn_type == FuseBatchNormalizationType::CONVOLUTION)
168     {
169         ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(3) != bn_mean->dimension(0));
170     }
171     else
172     {
173         const size_t channel_idx = get_data_layout_dimension_index(input_weights->data_layout(), DataLayoutDimension::CHANNEL);
174         ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(channel_idx) != bn_mean->dimension(0));
175     }
176     // Validate bias
177     if(input_bias != nullptr)
178     {
179         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, input_bias);
180         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, input_bias);
181     }
182     // Validate beta
183     if(bn_beta != nullptr)
184     {
185         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_beta);
186         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_beta);
187     }
188     // Validate gamma
189     if(bn_gamma != nullptr)
190     {
191         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_gamma);
192         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_gamma);
193     }
194 
195     // Validate output weights
196     if(fused_weights != nullptr && fused_weights->total_size() != 0)
197     {
198         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_weights, fused_weights);
199         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input_weights, fused_weights);
200         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_weights);
201     }
202     // Validate output bias
203     if(fused_bias != nullptr && fused_bias->total_size() != 0)
204     {
205         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, fused_bias);
206         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_bias);
207     }
208 
209     return Status{};
210 }
211 
212 } // namespace
213 
NEFuseBatchNormalizationKernel()214 NEFuseBatchNormalizationKernel::NEFuseBatchNormalizationKernel()
215     : _input_weights(nullptr), _input_bias(nullptr), _bn_mean(nullptr), _bn_var(nullptr), _bn_gamma(nullptr), _bn_beta(nullptr), _fused_weights(nullptr), _fused_bias(nullptr), _epsilon(),
216       _run_in_place_weights(false), _run_in_place_bias(false), _func(nullptr)
217 {
218 }
219 
configure(const ITensor * input_weights,const ITensor * bn_mean,const ITensor * bn_var,ITensor * fused_weights,ITensor * fused_bias,const ITensor * input_bias,const ITensor * bn_beta,const ITensor * bn_gamma,float epsilon,FuseBatchNormalizationType fbn_type)220 void NEFuseBatchNormalizationKernel::configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var,
221                                                ITensor *fused_weights, ITensor *fused_bias,
222                                                const ITensor *input_bias, const ITensor *bn_beta, const ITensor *bn_gamma,
223                                                float epsilon, FuseBatchNormalizationType fbn_type)
224 {
225     ARM_COMPUTE_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
226 
227     _input_weights = input_weights;
228     _input_bias    = input_bias;
229     _bn_mean       = bn_mean;
230     _bn_var        = bn_var;
231     _bn_beta       = bn_beta;
232     _bn_gamma      = bn_gamma;
233     _fused_weights = fused_weights;
234     _fused_bias    = fused_bias;
235     _epsilon       = epsilon;
236 
237     _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == input_weights);
238     _run_in_place_bias    = (fused_bias == nullptr) || (input_bias != nullptr && fused_bias == input_bias);
239 
240     // Auto initialize outputs
241     if(_fused_weights != nullptr)
242     {
243         // Output tensor auto initialization if not yet initialized
244         auto_init_if_empty(*_fused_weights->info(), *_input_weights->info()->clone());
245     }
246     if(_fused_bias != nullptr)
247     {
248         // Output tensor auto initialization if not yet initialized
249         auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
250     }
251 
252     // Validate arguments
253     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input_weights->info(), bn_mean->info(), bn_var->info(),
254                                                   (fused_weights != nullptr) ? fused_weights->info() : nullptr,
255                                                   (fused_bias != nullptr) ? fused_bias->info() : nullptr,
256                                                   (input_bias != nullptr) ? input_bias->info() : nullptr,
257                                                   (bn_beta != nullptr) ? bn_beta->info() : nullptr,
258                                                   (bn_gamma != nullptr) ? bn_gamma->info() : nullptr,
259                                                   epsilon, fbn_type));
260 
261     const auto *uk = get_implementation(FuseBatchNormalizeSelectorData{ input_weights->info()->data_type(), input_weights->info()->data_layout(), fbn_type, CPUInfo::get().get_isa() });
262     ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
263     ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
264     _func = uk->ukernel;
265 
266     // Configure kernel window
267     Window win = calculate_max_window(*input_weights->info());
268     INEKernel::configure(win);
269 }
270 
validate(const ITensorInfo * input_weights,const ITensorInfo * bn_mean,const ITensorInfo * bn_var,const ITensorInfo * fused_weights,const ITensorInfo * fused_bias,const ITensorInfo * input_bias,const ITensorInfo * bn_beta,const ITensorInfo * bn_gamma,float epsilon,FuseBatchNormalizationType fbn_type)271 Status NEFuseBatchNormalizationKernel::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
272                                                 const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
273                                                 const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
274                                                 float epsilon, FuseBatchNormalizationType fbn_type)
275 {
276     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type));
277     return Status{};
278 }
279 
run(const Window & window,const ThreadInfo & info)280 void NEFuseBatchNormalizationKernel::run(const Window &window, const ThreadInfo &info)
281 {
282     ARM_COMPUTE_UNUSED(info);
283     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
284     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
285 
286     ARM_COMPUTE_ERROR_ON(_func == nullptr);
287     (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon, window);
288 }
289 } // namespace arm_compute
290