/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h"

#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include "src/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
#include "src/core/NEON/wrapper/wrapper.h"

#include "src/core/NEON/kernels/batchnormalization/impl/list.h"
#include "src/core/common/Registrars.h"

#include <map>
namespace arm_compute
{
namespace
{
struct BatchNormalizationSelectorData
{
    DataType       dt;
    const CPUInfo &ci;
};
using BatchNormalizationSelectorPtr = std::add_pointer<bool(const BatchNormalizationSelectorData &data)>::type;
using BatchNormalizationKernelPtr   = std::add_pointer<void(ITensor *, ITensor *, const ITensor *, const ITensor *, const ITensor *, const ITensor *,
                                                            float, ActivationLayerInfo &, const Window &)>::type;
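// All micro-kernels share the signature above: (src, dst, mean, var, beta,
// gamma, epsilon, activation info, window); see the call site in run() below.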

struct BatchNormalizationKernel
{
    const char                         *name;
    const BatchNormalizationSelectorPtr is_selected;
    BatchNormalizationKernelPtr         ukernel;
};

static const BatchNormalizationKernel available_kernels[] =
{
#if defined(ARM_COMPUTE_ENABLE_SVE)
    {
        "sve_fp16_batch_normalization",
        [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_sve(); },
        REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_batch_normalization)
    },
    {
        "sve_fp32_batch_normalization",
        [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F32 && data.ci.has_sve(); },
        REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_batch_normalization)
    },
#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
#if defined(ARM_COMPUTE_ENABLE_NEON)
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    {
        "neon_fp16_batch_normalization",
        [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F16; },
        REGISTER_FP16_NEON(arm_compute::cpu::fp16_neon_batch_normalization)
    },
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
    {
        "neon_fp32_batch_normalization",
        [](const BatchNormalizationSelectorData & data) { return data.dt == DataType::F32; },
        REGISTER_FP32_NEON(arm_compute::cpu::fp32_neon_batch_normalization)
    },
#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
};
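// get_implementation() returns the first entry whose predicate matches, so
// the SVE kernels above take priority over the NEON ones whenever both are
// compiled in and the CPU reports SVE support.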

const BatchNormalizationKernel *get_implementation(const BatchNormalizationSelectorData &data)
{
    for(const auto &uk : available_kernels)
    {
        if(uk.is_selected(data))
        {
            return &uk;
        }
    }
    return nullptr;
}

Status
validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *mean, const ITensorInfo *var,
                   const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon, ActivationLayerInfo act_info)
{
    ARM_COMPUTE_UNUSED(epsilon);

    const auto *uk = get_implementation(BatchNormalizationSelectorData{ input->data_type(), CPUInfo::get() });
    ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);

    if(act_info.enabled())
    {
        ActivationLayerInfo::ActivationFunction act = act_info.activation();
        ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationFunction::RELU
                                    && act != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
                                    && act != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
        ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
    }

    if(nullptr != output)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
    if(beta != nullptr)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
    }
    if(gamma != nullptr)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
    }
    ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));

    return Status{};
}
} // namespace

template <typename T, bool fused_activation, typename F>
void NEBatchNormalizationLayerKernel::batch_normalization_nchw(const Window &window)
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    const int  window_step_x  = 16 / sizeof(T);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    // Collapse the X dimension to a single step: each row is traversed manually
    // below, vectorized in chunks of window_step_x elements.
    Window win_to_use = window;
    win_to_use.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input(_input, win_to_use);
    Iterator output(_output, win_to_use);

    F activation_functor(_act_info);

    // Hold information about the current feature map we are iterating.
    // Only compute denominator and constants once per feature map.
    int slice = -1;

    const auto input_mean  = reinterpret_cast<const T *>(_mean->ptr_to_element(Coordinates(0, 0)));
    const auto input_var   = reinterpret_cast<const T *>(_var->ptr_to_element(Coordinates(0, 0)));
    const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const T *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
    const auto input_beta  = (_beta != nullptr) ? reinterpret_cast<const T *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;

    T mean        = static_cast<T>(0);
    T var         = static_cast<T>(0);
    T gamma       = static_cast<T>(1);
    T beta        = static_cast<T>(0);
    T denominator = static_cast<T>(0);

    auto       mean_vec        = wrapper::vdup_n(mean, ExactTagType{});
    auto       var_vec         = wrapper::vdup_n(var, ExactTagType{});
    auto       gamma_vec       = wrapper::vdup_n(gamma, ExactTagType{});
    auto       beta_vec        = wrapper::vdup_n(beta, ExactTagType{});
    auto       denominator_vec = wrapper::vdup_n(denominator, ExactTagType{});
    const auto epsilon_vec     = wrapper::vdup_n(static_cast<T>(_epsilon), ExactTagType{});
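    // Per channel c (the Z dimension in NCHW), the loop below computes
    //   x_hat = (x - mean[c]) / sqrt(var[c] + epsilon)
    //   out   = gamma[c] * x_hat + beta[c]
    // with gamma defaulting to 1 and beta to 0 when not provided.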
    execute_window_loop(win_to_use, [&](const Coordinates & id)
    {
        const auto input_ptr  = reinterpret_cast<const T *>(input.ptr());
        const auto output_ptr = reinterpret_cast<T *>(output.ptr());

        if(slice != id.z())
        {
            mean     = input_mean[id.z()];
            var      = input_var[id.z()];
            mean_vec = wrapper::vdup_n(mean, ExactTagType{});
            var_vec  = wrapper::vdup_n(var, ExactTagType{});
            if(input_gamma != nullptr)
            {
                gamma     = input_gamma[id.z()];
                gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
            }
            if(input_beta != nullptr)
            {
                beta     = input_beta[id.z()];
                beta_vec = wrapper::vdup_n(beta, ExactTagType{});
            }

            // Calculate denominator, keeping a scalar copy for the leftover loop
            denominator_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
            denominator     = wrapper::vgetlane(denominator_vec, 0);
            slice           = id.z();
        }

        // Perform core calculations using vector operations
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            // Calculate x bar
            const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
            const auto x_bar     = wrapper::vmul(numerator, denominator_vec);
            auto       res       = wrapper::vmla(beta_vec, x_bar, gamma_vec); // beta + x_bar * gamma

            // Perform fused activation
            if(fused_activation)
            {
                activation_functor(res);
            }

            // Store results
            wrapper::vstore(output_ptr + x, res);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            const T numerator = input_ptr[x] - mean;
            const T x_bar     = numerator * denominator;
            T       res       = beta + x_bar * gamma;

            // Perform fused activation
            if(fused_activation)
            {
                activation_functor(res);
            }

            // Store results
            *(output_ptr + x) = res;
        }
    },
    input, output);
}

void NEBatchNormalizationLayerKernel::configure_non_fused()
{
    switch(_input->info()->data_type())
    {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, false, detail::dummy<float16_t, 8>>;
            break;
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F32:
            _func = &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, false, detail::dummy<float, 4>>;
            break;
        default:
            ARM_COMPUTE_ERROR("Data type not supported");
            break;
    }
}

void NEBatchNormalizationLayerKernel::configure_fused()
{
    // NCHW fused batch normalization with activation functions : FP32
    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw =
    {
        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::relu<float, 4>> },
        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::brelu<float, 4>> },
        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::lubrelu<float, 4>> }
    };
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // NCHW fused batch normalization with activation functions : FP16
    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw =
    {
        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::relu<float16_t, 8>> },
        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::brelu<float16_t, 8>> },
        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::lubrelu<float16_t, 8>> }
    };
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

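    // validate_arguments() has already restricted the fused activation to
    // RELU / BOUNDED_RELU / LU_BOUNDED_RELU, so the lookups below always
    // find an entry and never insert a null function pointer.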
    switch(_input->info()->data_type())
    {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = bn_fused_map_f16_nchw[_act_info.activation()];
            break;
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F32:
            _func = bn_fused_map_f32_nchw[_act_info.activation()];
            break;
        default:
            ARM_COMPUTE_ERROR("Data type not supported");
            break;
    }
}

NEBatchNormalizationLayerKernel::NEBatchNormalizationLayerKernel()
    : _func(nullptr), _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _gamma(nullptr), _beta(nullptr), _epsilon(), _act_info()
{
}

void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output,
                                                const ITensor *mean, const ITensor *var,
                                                const ITensor *beta, const ITensor *gamma,
                                                float epsilon, ActivationLayerInfo act_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr,
                                                  mean->info(), var->info(),
                                                  (beta != nullptr) ? beta->info() : nullptr,
                                                  (gamma != nullptr) ? gamma->info() : nullptr,
                                                  epsilon, act_info));

    _input    = input;
    _output   = input; // Default to running in-place; overwritten below if a distinct output is given
    _mean     = mean;
    _var      = var;
    _gamma    = gamma;
    _beta     = beta;
    _epsilon  = epsilon;
    _act_info = act_info;

    const bool run_in_place = (output == nullptr) || (output == input);
    if(!run_in_place)
    {
        _output = output;
    }

    // Configure activation function to run
    const bool is_nchw = _input->info()->data_layout() == DataLayout::NCHW;
    if(is_nchw)
    {
        if(_act_info.enabled())
        {
            configure_fused();
        }
        else
        {
            configure_non_fused();
        }
    }

    // Configure kernel window
    Window win = calculate_max_window(*input->info(), Steps());
    INEKernel::configure(win);

    if(output != nullptr)
    {
        // Output auto initialization if not yet initialized
        auto_init_if_empty(*output->info(), *input->info()->clone());
    }
}
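
// A minimal usage sketch (illustrative only; assumes the tensors are already
// allocated and filled -- the tensor names and the split dimension here are
// assumptions, not taken from this file):
//
//   NEBatchNormalizationLayerKernel kernel;
//   kernel.configure(&src, &dst, &mean, &var, &beta, &gamma,
//                    0.001f /* epsilon */, ActivationLayerInfo());
//   NEScheduler::get().schedule(&kernel, Window::DimY);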

Status NEBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output,
                                                 const ITensorInfo *mean, const ITensorInfo *var,
                                                 const ITensorInfo *beta, const ITensorInfo *gamma,
                                                 float epsilon, ActivationLayerInfo act_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));

    return Status{};
}

void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr && _input->info()->data_layout() == DataLayout::NCHW);
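
    // NCHW tensors run through the template path selected at configure()
    // time; NHWC tensors dispatch to the registered micro-kernel, looked up
    // again here from the data type and runtime CPU information.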
    const bool is_nchw = _input->info()->data_layout() == DataLayout::NCHW;
    if(is_nchw)
    {
        (this->*_func)(window);
    }
    else
    {
        const auto *uk = get_implementation(BatchNormalizationSelectorData{ _input->info()->data_type(), CPUInfo::get() });
        ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
        uk->ukernel(_input, _output, _mean, _var, _beta, _gamma, _epsilon, _act_info, window);
    }
}
} // namespace arm_compute