xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/ConvImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "RefWorkloadUtils.hpp"
9 #include "TensorBufferArrayView.hpp"
10 #include "BaseIterator.hpp"
11 #include "Decoders.hpp"
12 #include "Encoders.hpp"
13 
14 #include <armnn/Tensor.hpp>
15 
16 #include <armnnUtils/DataLayoutIndexed.hpp>
17 
18 #include <cmath>
19 #include <limits>
20 
21 namespace armnn
22 {
23 
/// Integer-only multiplication by a factor that is less than one.
///
/// The factor is supplied once, as a float, at construction time; afterwards operator* applies it
/// using quantized integer arithmetic. The arithmetic is intended to be consistent with
/// AndroidNN's CPU executor.
struct QuantizedMultiplierSmallerThanOne
{
    /// Creates an object that multiplies by @p multiplier.
    ///
    /// The integer quantities derived from @p multiplier are stored in the object for later use
    /// by operator*. Adapted from Android NN's QuantizeMultiplierSmallerThanOne().
    ///
    /// NOTE(review): the type name implies @p multiplier must be below one; whether (and how) that
    /// is checked is decided by the definition, which is not part of this header - confirm there.
    ///
    /// The constructor is intentionally left non-explicit here so that existing call sites relying
    /// on implicit conversion from float keep compiling.
    QuantizedMultiplierSmallerThanOne(float multiplier);

    /// Multiplies @p rhs by the multiplier captured at construction time.
    ///
    /// Adapted from Android NN's MultiplyByQuantizedMultiplierSmallerThanOne().
    int32_t operator*(int32_t rhs) const;

private:
    /// Helper adapted from gemmlowp's RoundingDivideByPOT().
    static int32_t RoundingDivideByPOT(int32_t x, int exponent);

    /// Helper adapted from gemmlowp's SaturatingRoundingDoublingHighMul().
    static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b);

    // Integer quantities derived from the constructor's float multiplier.
    // NOTE(review): presumably a fixed-point mantissa and the matching right-shift amount,
    // as in the Android NN routine this is adapted from - confirm against the definition.
    int32_t m_Multiplier;
    int32_t m_RightShift;
};
47 
48 void Convolve(const TensorShape& rInputShape,
49               Decoder<float>& rInputDecoder,
50               const TensorShape& rOutputShape,
51               Encoder<float>& rOutputEncoder,
52               const TensorShape& rFilterShape,
53               Decoder<float>& rFilterDecoder,
54               bool biasEnabled,
55               Decoder<float>* pBiasDecoder,
56               DataLayout dataLayout,
57               unsigned int paddingTop,
58               unsigned int paddingLeft,
59               unsigned int xStride,
60               unsigned int yStride,
61               unsigned int xDilation,
62               unsigned int yDilation,
63               bool depthwise = false);
64 } //namespace armnn
65