xref: /aosp_15_r20/external/armnn/src/armnnUtils/DataLayoutIndexed.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
DataLayoutIndexed(armnn::DataLayout dataLayout)13*89c4ff92SAndroid Build Coastguard Worker DataLayoutIndexed::DataLayoutIndexed(armnn::DataLayout dataLayout)
14*89c4ff92SAndroid Build Coastguard Worker     : m_DataLayout(dataLayout)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker     switch (dataLayout)
17*89c4ff92SAndroid Build Coastguard Worker     {
18*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NHWC:
19*89c4ff92SAndroid Build Coastguard Worker             m_ChannelsIndex = 3;
20*89c4ff92SAndroid Build Coastguard Worker             m_HeightIndex   = 1;
21*89c4ff92SAndroid Build Coastguard Worker             m_WidthIndex    = 2;
22*89c4ff92SAndroid Build Coastguard Worker             break;
23*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NCHW:
24*89c4ff92SAndroid Build Coastguard Worker             m_ChannelsIndex = 1;
25*89c4ff92SAndroid Build Coastguard Worker             m_HeightIndex   = 2;
26*89c4ff92SAndroid Build Coastguard Worker             m_WidthIndex    = 3;
27*89c4ff92SAndroid Build Coastguard Worker             break;
28*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NDHWC:
29*89c4ff92SAndroid Build Coastguard Worker             m_DepthIndex    = 1;
30*89c4ff92SAndroid Build Coastguard Worker             m_HeightIndex   = 2;
31*89c4ff92SAndroid Build Coastguard Worker             m_WidthIndex    = 3;
32*89c4ff92SAndroid Build Coastguard Worker             m_ChannelsIndex = 4;
33*89c4ff92SAndroid Build Coastguard Worker             break;
34*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NCDHW:
35*89c4ff92SAndroid Build Coastguard Worker             m_ChannelsIndex = 1;
36*89c4ff92SAndroid Build Coastguard Worker             m_DepthIndex    = 2;
37*89c4ff92SAndroid Build Coastguard Worker             m_HeightIndex   = 3;
38*89c4ff92SAndroid Build Coastguard Worker             m_WidthIndex    = 4;
39*89c4ff92SAndroid Build Coastguard Worker             break;
40*89c4ff92SAndroid Build Coastguard Worker         default:
41*89c4ff92SAndroid Build Coastguard Worker             throw armnn::InvalidArgumentException("Unknown DataLayout value: " +
42*89c4ff92SAndroid Build Coastguard Worker                                                   std::to_string(static_cast<int>(dataLayout)));
43*89c4ff92SAndroid Build Coastguard Worker     }
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker 
operator ==(const DataLayout & dataLayout,const DataLayoutIndexed & indexed)46*89c4ff92SAndroid Build Coastguard Worker bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker     return dataLayout == indexed.GetDataLayout();
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker 
operator ==(const DataLayoutIndexed & indexed,const DataLayout & dataLayout)51*89c4ff92SAndroid Build Coastguard Worker bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker     return indexed.GetDataLayout() == dataLayout;
54*89c4ff92SAndroid Build Coastguard Worker }
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
57