1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "FullyConnected.hpp"
7
8 #include <armnn/utility/Assert.hpp>
9
10 #include "RefWorkloadUtils.hpp"
11
12 namespace armnn
13 {
14
FullyConnected(const TensorShape & rInputShape,Decoder<float> & rInputDecoder,const TensorShape & rOutputShape,Encoder<float> & rOutputEncoder,const TensorShape & rWeightsShape,Decoder<float> & rWeightDecoder,Decoder<float> * pBiasDecoder,const bool biasEnabled,const unsigned int K,const bool transposeWeights)15 void FullyConnected(const TensorShape& rInputShape,
16 Decoder<float>& rInputDecoder,
17 const TensorShape& rOutputShape,
18 Encoder<float>& rOutputEncoder,
19 const TensorShape& rWeightsShape,
20 Decoder<float>& rWeightDecoder,
21 Decoder<float>* pBiasDecoder,
22 const bool biasEnabled,
23 const unsigned int K,
24 const bool transposeWeights)
25 {
26 // Perform FullyConnected implementation
27 unsigned int outputSize = rOutputShape[1];
28
29 const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape);
30 const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape);
31
32 const TensorShape biasShape{outputSize};
33
34 ARMNN_ASSERT(!biasEnabled || pBiasDecoder != nullptr);
35 const std::vector<float> decodedBiases = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
36
37
38 for (unsigned int n = 0; n < rInputShape[0]; n++)
39 {
40 for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
41 {
42 float outval = 0.f;
43
44 for (unsigned int channelInput = 0; channelInput < K; channelInput++)
45 {
46 float weight;
47 if (transposeWeights)
48 {
49 weight = decodedWeights[channelOutput * K + channelInput];
50 }
51 else
52 {
53 weight = decodedWeights[channelInput * outputSize + channelOutput];
54 }
55
56 outval += weight * decodedInputs[n * K + channelInput];
57 }
58
59 if (biasEnabled)
60 {
61 outval += decodedBiases[channelOutput];
62 }
63
64 rOutputEncoder[n * outputSize + channelOutput];
65 rOutputEncoder.Set(outval);
66 }
67 }
68 }
69
70 } //namespace armnn
71