xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/InstanceNorm.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "InstanceNorm.hpp"
7 #include "RefWorkloadUtils.hpp"
8 
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnnUtils/DataLayoutIndexed.hpp>
12 
13 #include <cmath>
14 
15 namespace armnn
16 {
17 
InstanceNorm(const InstanceNormalizationQueueDescriptor & data,const TensorInfo & inputInfo,Decoder<float> & inputDecoder,Encoder<float> & outputEncoder)18 void InstanceNorm(const InstanceNormalizationQueueDescriptor& data,
19                   const TensorInfo& inputInfo,
20                   Decoder<float>& inputDecoder,
21                   Encoder<float>& outputEncoder)
22 {
23     const TensorShape inputShape = inputInfo.GetShape();
24 
25     armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);
26 
27     unsigned int inputBatches  = inputShape[0];
28     unsigned int inputHeight   = inputShape[dataLayout.GetHeightIndex()];
29     unsigned int inputWidth    = inputShape[dataLayout.GetWidthIndex()];
30     unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
31 
32     float beta  = data.m_Parameters.m_Beta;
33     float eps   = data.m_Parameters.m_Eps;
34     float gamma = data.m_Parameters.m_Gamma;
35 
36     for (unsigned int n = 0; n < inputBatches; ++n)
37     {
38         for (unsigned int c = 0; c < inputChannels; ++c)
39         {
40             float mean = 0, var = 0;
41 
42             //Calculate Mean
43             for (unsigned int h = 0; h < inputHeight; h++)
44             {
45                 for (unsigned int w = 0; w < inputWidth; w++)
46                 {
47                     unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
48 
49                     inputDecoder[index];
50                     float value = inputDecoder.Get();
51                     mean += value;
52                 }
53             }
54             mean /= static_cast<float>(inputHeight * inputWidth);
55 
56             //Calculate Variance
57             for (unsigned int h = 0; h < inputHeight; h++)
58             {
59                 for (unsigned int w = 0; w < inputWidth; w++)
60                 {
61                     unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
62 
63                     inputDecoder[index];
64                     float value = inputDecoder.Get();
65                     var += (value - mean) * (value - mean);
66                 }
67             }
68             var /= static_cast<float>(inputHeight * inputWidth);
69 
70             // Apply Instance Normalisation
71             for (unsigned int h = 0; h < inputHeight; ++h)
72             {
73                 for (unsigned int w = 0; w < inputWidth; ++w)
74                 {
75                     unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
76                     inputDecoder[index];
77                     outputEncoder[index];
78                     outputEncoder.Set((inputDecoder.Get() - mean) * gamma /  std::sqrt ( var + eps) + beta);
79                 }
80 
81             }
82         }
83     }
84 }
85 
86 } // namespace armnn
87