xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefChannelShuffleWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/backends/ITensorHandleFactory.hpp>
7 #include <armnnUtils/Transpose.hpp>
8 #include "RefChannelShuffleWorkload.hpp"
9 #include "RefWorkloadUtils.hpp"
10 #include "Profiling.hpp"
11 #include "Decoders.hpp"
12 #include "Encoders.hpp"
13 
14 namespace armnn
15 {
// Synchronous entry point: runs the channel shuffle on the tensors bound to
// this workload's own queue descriptor (m_Data).
void RefChannelShuffleWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}
20 
// Asynchronous entry point: tensors come from caller-supplied working memory
// instead of m_Data, so the same workload object can serve concurrent requests.
void RefChannelShuffleWorkload::ExecuteAsync(ExecutionData& executionData)
{
    // NOTE(review): assumes ExecutionData::m_Data holds a WorkingMemDescriptor*
    // for reference-backend workloads — the cast is unchecked here.
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}
26 
27 // Reference implementation for channel shuffle taken from
28 // https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/master/nn/common/operations/ChannelShuffle.cpp
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const29 void RefChannelShuffleWorkload::Execute(std::vector<ITensorHandle*> inputs,
30                                         std::vector<ITensorHandle*> outputs) const
31 {
32     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefChannelShuffleWorkload_Execute");
33 
34     const TensorInfo& inputInfo  = GetTensorInfo(inputs[0]);
35     const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
36     std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo, inputs[0]->Map());
37     Decoder<float>& decoder = *decoderPtr;
38 
39     std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
40     Encoder<float>& encoder = *encoderPtr;
41 
42     auto getNumberOfElements = [](const TensorShape& tensorShape,uint32_t startAxis, uint32_t lastAxis)
43     {
44         uint32_t count = 1;
45         for (uint32_t i = startAxis; i < lastAxis; i++)
46         {
47             count *= tensorShape[i];
48         }
49         return count;
50     };
51     const TensorShape tensorShape = GetTensorInfo(inputs[0]).GetShape();
52     uint32_t channelsAxis = m_Data.m_Parameters.m_Axis; // channelsAxis to perform channel shuffle on
53 
54     const uint32_t numGroups = m_Data.m_Parameters.m_NumGroups;
55     const uint32_t groupSize = tensorShape[channelsAxis] / numGroups;
56 
57     uint32_t outerSize = getNumberOfElements(tensorShape, 0, channelsAxis);
58     uint32_t innerSize = getNumberOfElements(tensorShape, channelsAxis + 1, tensorShape.GetNumDimensions());
59 
60     for (uint32_t outer = 0; outer < outerSize; ++outer)
61     {
62         for (uint32_t inner = 0; inner < innerSize; ++inner)
63         {
64             uint32_t decoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner;
65             decoder += decoderStep1;
66             uint32_t encoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner;
67             encoder += encoderStep1;
68             for (uint32_t i = 0; i < groupSize; i++)
69             {
70                 for (uint32_t j = 0; j < numGroups; j++, encoder += innerSize, encoderStep1 += innerSize)
71                 {
72                     decoder += innerSize * (i + j * groupSize);
73                     float decoded = decoder.Get();
74                     encoder.Set(decoded);
75                     decoder -= innerSize * (i + j * groupSize);
76                 }
77             }
78             decoder -= decoderStep1;
79             encoder -= encoderStep1;
80         }
81     }
82 }
83 }