1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_INDIRECT_CONV2D_ADDRESS_PRECALCULATION_FIXTURE
25 #define ARM_COMPUTE_TEST_INDIRECT_CONV2D_ADDRESS_PRECALCULATION_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
30 #include "tests/Globals.h"
31 #include "tests/framework/Fixture.h"
32 #include "tests/validation/Helpers.h"
33 #include "tests/validation/reference/IndirectConv2dAddressPrecalculation.h"
34 
35 namespace arm_compute
36 {
37 namespace test
38 {
39 namespace validation
40 {
41 using namespace arm_compute::misc::shape_calculator;
42 
43 template <typename TensorType, typename AccessorType, typename OperatorType>
44 class IndirectConv2dAddressPrecalculationValidationFixture : public framework::Fixture
45 {
46 public:
47     template <typename...>
setup(unsigned int src_w,unsigned int src_h,unsigned int src_b,unsigned int wei_w,unsigned int wei_h,unsigned int pad,unsigned int stride,unsigned int m0)48     void setup(unsigned int src_w,
49                unsigned int src_h,
50                unsigned int src_b,
51                unsigned int wei_w,
52                unsigned int wei_h,
53                unsigned int pad,
54                unsigned int stride,
55                unsigned int m0)
56     {
57         DirectConvComputeKernelInfo desc;
58         desc.m0                         = m0;
59         desc.n0                         = 1;     // Not used by the kernel
60         desc.k0                         = 1;     // Not used by the kernel
61         desc.export_weights_to_cl_image = false; // Not used by the kernel
62 
63         const PadStrideInfo conv_info(stride, stride, pad, pad);
64 
65         const TensorShape shape_conv_src(23, // The input channels are not used by the kernel
66                                          src_w,
67                                          src_h,
68                                          src_b);
69 
70         const TensorShape shape_conv_wei(23, // The input channels are not used by the kernel
71                                          wei_w,
72                                          wei_h,
73                                          23 // The output channels are not used by the kernel
74                                         );
75 
76         // The result of the kernel does not change with the datatype. Hence, we can fix it to Fp16 for validation purposes
77         const DataType data_type = DataType::F16;
78 
79         _target    = compute_target(shape_conv_src, shape_conv_wei, data_type, conv_info, desc);
80         _reference = compute_reference(shape_conv_src, shape_conv_wei, data_type, conv_info, desc);
81     }
82 
83 protected:
compute_target(TensorShape shape_conv_src,TensorShape shape_conv_wei,DataType data_type,const PadStrideInfo & conv_info,const DirectConvComputeKernelInfo & desc)84     TensorType compute_target(TensorShape shape_conv_src, TensorShape shape_conv_wei, DataType data_type, const PadStrideInfo &conv_info, const DirectConvComputeKernelInfo &desc)
85     {
86         TensorInfo src_conv_info(shape_conv_src, 1, data_type, DataLayout::NHWC);
87         TensorInfo wei_conv_info(shape_conv_wei, 1, data_type, DataLayout::NHWC);
88         TensorType dst;
89 
90         // The output tensor will be auto-initialized within the function
91 
92         // Create and configure function
93         OperatorType func;
94         func.configure(&src_conv_info, &wei_conv_info, dst.info(), conv_info, desc);
95 
96         add_padding_x({ &dst });
97 
98         // Allocate tensors
99         dst.allocator()->allocate();
100 
101         // Compute GEMM LHS matrix reshape function
102         ITensorPack tensors = { { ACL_DST, &dst } };
103         func.run(tensors);
104 
105         return dst;
106     }
107 
compute_reference(TensorShape shape_conv_src,TensorShape shape_conv_wei,DataType data_type,const PadStrideInfo & conv_info,const DirectConvComputeKernelInfo & desc)108     SimpleTensor<int32_t> compute_reference(TensorShape shape_conv_src, TensorShape shape_conv_wei, DataType data_type, const PadStrideInfo &conv_info, const DirectConvComputeKernelInfo &desc)
109     {
110         ARM_COMPUTE_UNUSED(data_type);
111         TensorShape shape_out         = compute_indirect_buffer_shape(shape_conv_src, DataLayout::NHWC, shape_conv_wei, conv_info, desc);
112         TensorShape output_conv_shape = compute_deep_convolution_shape(shape_conv_src, DataLayout::NHWC, shape_conv_wei, conv_info);
113 
114         return reference::indirect_conv2d_addr_precalculation(shape_conv_src, shape_conv_wei, output_conv_shape, shape_out, conv_info);
115     }
116 
117     TensorType            _target{};
118     SimpleTensor<int32_t> _reference{};
119 };
120 } // namespace validation
121 } // namespace test
122 } // namespace arm_compute
123 #endif /* ARM_COMPUTE_TEST_INDIRECT_CONV2D_ADDRESS_PRECALCULATION_FIXTURE */