xref: /aosp_15_r20/external/armnn/src/backends/neon/workloads/NeonChannelShuffleWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "NeonChannelShuffleWorkload.hpp"
7 #include "NeonWorkloadUtils.hpp"
8 
9 #include <aclCommon/ArmComputeTensorHandle.hpp>
10 #include <aclCommon/ArmComputeTensorUtils.hpp>
11 
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 
14 namespace armnn
15 {
16 
NeonChannelShuffleValidate(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor)17 arm_compute::Status NeonChannelShuffleValidate(const TensorInfo& input,
18                                                const TensorInfo& output,
19                                                const ChannelShuffleDescriptor& descriptor)
20 {
21     arm_compute::TensorInfo aclInputInfo  = armcomputetensorutils::BuildArmComputeTensorInfo(input);
22     arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23 
24     // In Arm NN and in NNAPI, channel shuffle implementation is datalayout agnostic and it has axis as a parameter.
25     // The channel shuffle Implementation for Neon is dependent on datalayout and does not have axis as a parameter,
26     // it only supports channel shuffle for 4D tensors in dimension C (1 or 3).
27     arm_compute::DataLayout aclDataLayout;
28     if (input.GetNumDimensions() == 4)
29     {
30         switch (descriptor.m_Axis)
31         {
32             case 1:
33                 aclDataLayout = ConvertDataLayout(armnn::DataLayout::NCHW);
34                 break;
35             case 3:
36                 aclDataLayout = ConvertDataLayout(armnn::DataLayout::NHWC);
37                 break;
38             default:
39                 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported axis"};
40         }
41         aclInputInfo.set_data_layout(aclDataLayout);
42         aclOutputInfo.set_data_layout(aclDataLayout);
43         return arm_compute::NEChannelShuffleLayer::validate(&aclInputInfo, &aclOutputInfo, descriptor.m_NumGroups);
44     }
45     else
46     {
47         return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported number of dimensions"};
48     }
49 }
50 
NeonChannelShuffleWorkload(const ChannelShuffleQueueDescriptor & descriptor,const WorkloadInfo & info)51 NeonChannelShuffleWorkload::NeonChannelShuffleWorkload(const ChannelShuffleQueueDescriptor& descriptor,
52                                                        const WorkloadInfo& info)
53     : NeonBaseWorkload<ChannelShuffleQueueDescriptor>(descriptor, info)
54 {
55     // Report Profiling Details
56     ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonChannelShufflenWorkload_Construct",
57                                          descriptor.m_Parameters,
58                                          info,
59                                          this->GetGuid());
60 
61     m_Data.ValidateInputsOutputs("NeonChannelShuffleWorkload", 1, 1);
62 
63     arm_compute::ITensor& input  = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
64     arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
65 
66     // In Arm NN and in NNAPI, channel shuffle implementation is datalayout agnostic and it has axis as a parameter.
67     // The channel shuffle Implementation for Neon is dependent on datalayout and does not have axis as a parameter,
68     // it only supports channel shuffle for 4D tensors in dimension C (1 or 3).
69     arm_compute::DataLayout aclDataLayout;
70     switch (descriptor.m_Parameters.m_Axis)
71     {
72         case 1:
73             aclDataLayout = ConvertDataLayout(armnn::DataLayout::NCHW);
74             break;
75         case 3:
76             aclDataLayout = ConvertDataLayout(armnn::DataLayout::NHWC);
77             break;
78         default:
79             ARMNN_ASSERT_MSG(false, "Unsupported axis");
80             break;
81     }
82     input.info()->set_data_layout(aclDataLayout);
83     output.info()->set_data_layout(aclDataLayout);
84 
85     m_ChannelShuffleLayer.configure(&input, &output, descriptor.m_Parameters.m_NumGroups);
86 }
87 
Execute() const88 void NeonChannelShuffleWorkload::Execute() const
89 {
90     ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonChannelShuffleWorkload_Execute", this->GetGuid());
91     m_ChannelShuffleLayer.run();
92 }
93 
94 } // namespace armnn
95