// xref: /aosp_15_r20/external/android-nn-driver/test/Dilation.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #pragma once
7 
8 #include "DriverTestHelpers.hpp"
9 
10 #include <armnn/StrategyBase.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12 
13 #include <numeric>
14 
15 using namespace armnn;
16 using namespace driverTestHelpers;
17 
/// Flags selecting which variant of the single-operation convolution model
/// DilationTestImpl builds. Every flag defaults to false.
struct DilationTestOptions
{
    /// Build a DEPTHWISE_CONV_2D operation (with an extra depth-multiplier
    /// operand) instead of CONV_2D.
    bool m_IsDepthwiseConvolution = false;

    /// Supply four explicit padding operands instead of a single
    /// padding-scheme operand.
    bool m_IsPaddingExplicit = false;

    /// Append the data-layout operand followed by the dilation X / Y operands.
    bool m_HasDilation = false;
};
32 
33 class DilationTestVisitor : public StrategyBase<ThrowingStrategy>
34 {
35 public:
DilationTestVisitor()36     DilationTestVisitor() :
37         DilationTestVisitor(1u, 1u)
38     {}
39 
DilationTestVisitor(uint32_t expectedDilationX,uint32_t expectedDilationY)40     DilationTestVisitor(uint32_t expectedDilationX, uint32_t expectedDilationY) :
41         m_ExpectedDilationX{expectedDilationX},
42         m_ExpectedDilationY{expectedDilationY}
43     {}
44 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)45     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
46                          const armnn::BaseDescriptor& descriptor,
47                          const std::vector<armnn::ConstTensor>& constants,
48                          const char* name,
49                          const armnn::LayerBindingId id = 0) override
50     {
51         armnn::IgnoreUnused(layer, constants, id, name);
52         switch (layer->GetType())
53         {
54             case armnn::LayerType::Constant:
55                 break;
56             case armnn::LayerType::Convolution2d:
57             {
58                 CheckDilationParams(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
59                 break;
60             }
61             case armnn::LayerType::DepthwiseConvolution2d:
62             {
63                 CheckDilationParams(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
64                 break;
65             }
66             default:
67             {
68                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
69             }
70         }
71     }
72 
73 private:
74     uint32_t m_ExpectedDilationX;
75     uint32_t m_ExpectedDilationY;
76 
77     template<typename ConvolutionDescriptor>
CheckDilationParams(const ConvolutionDescriptor & descriptor)78     void CheckDilationParams(const ConvolutionDescriptor& descriptor)
79     {
80         CHECK_EQ(descriptor.m_DilationX, m_ExpectedDilationX);
81         CHECK_EQ(descriptor.m_DilationY, m_ExpectedDilationY);
82     }
83 };
84 
85 template<typename HalPolicy>
DilationTestImpl(const DilationTestOptions & options)86 void DilationTestImpl(const DilationTestOptions& options)
87 {
88     using HalModel         = typename HalPolicy::Model;
89     using HalOperationType = typename HalPolicy::OperationType;
90 
91     const armnn::Compute backend = armnn::Compute::CpuRef;
92     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(backend, false));
93     HalModel model = {};
94 
95     // add operands
96     std::vector<float> weightData(9, 1.0f);
97     std::vector<float> biasData(1, 0.0f );
98 
99     // input
100     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3, 3, 1});
101 
102     // weights & biases
103     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3, 3, 1}, weightData.data());
104     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1}, biasData.data());
105 
106     uint32_t numInputs = 3u;
107     // padding
108     if (options.m_IsPaddingExplicit)
109     {
110         AddIntOperand<HalPolicy>(model, 1);
111         AddIntOperand<HalPolicy>(model, 1);
112         AddIntOperand<HalPolicy>(model, 1);
113         AddIntOperand<HalPolicy>(model, 1);
114         numInputs += 4;
115     }
116     else
117     {
118         AddIntOperand<HalPolicy>(model, android::nn::kPaddingSame);
119         numInputs += 1;
120     }
121 
122     AddIntOperand<HalPolicy>(model, 2); // stride x
123     AddIntOperand<HalPolicy>(model, 2); // stride y
124     numInputs += 2;
125 
126     if (options.m_IsDepthwiseConvolution)
127     {
128         AddIntOperand<HalPolicy>(model, 1); // depth multiplier
129         numInputs++;
130     }
131 
132     AddIntOperand<HalPolicy>(model, 0); // no activation
133     numInputs += 1;
134 
135     // dilation
136     if (options.m_HasDilation)
137     {
138         AddBoolOperand<HalPolicy>(model, false); // default data layout
139 
140         AddIntOperand<HalPolicy>(model, 2); // dilation X
141         AddIntOperand<HalPolicy>(model, 2); // dilation Y
142 
143         numInputs += 3;
144     }
145 
146     // output
147     AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, 1, 1});
148 
149     // set up the convolution operation
150     model.operations.resize(1);
151     model.operations[0].type = options.m_IsDepthwiseConvolution ?
152         HalOperationType::DEPTHWISE_CONV_2D : HalOperationType::CONV_2D;
153 
154     std::vector<uint32_t> inputs(numInputs);
155     std::iota(inputs.begin(), inputs.end(), 0u);
156     std::vector<uint32_t> outputs = { numInputs };
157 
158     model.operations[0].inputs  = hidl_vec<uint32_t>(inputs);
159     model.operations[0].outputs = hidl_vec<uint32_t>(outputs);
160 
161     // convert model
162     ConversionData data({backend});
163     data.m_Network = armnn::INetwork::Create();
164     data.m_OutputSlotForOperand = std::vector<IOutputSlot*>(model.operands.size(), nullptr);
165 
166     bool ok = HalPolicy::ConvertOperation(model.operations[0], model, data);
167     DOCTEST_CHECK(ok);
168 
169     // check if dilation params are as expected
170     DilationTestVisitor visitor = options.m_HasDilation ? DilationTestVisitor(2, 2) : DilationTestVisitor();
171     data.m_Network->ExecuteStrategy(visitor);
172 }
173