/*
 * Copyright (c) 2020-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE
#define ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE

#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/IAccessor.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Helpers.h"
#include "tests/validation/reference/QLSTMLayerNormalization.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class QLSTMLayerNormalizationValidationFixture : public framework::Fixture
{
public:
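    /** Set up the fixture: require QSYMM16 input, store the data type and weight quantization info,
     * then compute both the target and the reference outputs for later comparison. */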
    template <typename...>
    void setup(TensorShape input_shape, TensorShape weight_shape, TensorShape bias_shape, DataType data_type, QuantizationInfo weight_qinfo)
    {
        ARM_COMPUTE_ERROR_ON(data_type != DataType::QSYMM16);

        _data_type = data_type;
        _qinfo     = weight_qinfo;

        _target    = compute_target(input_shape, weight_shape, bias_shape);
        _reference = compute_reference(input_shape, weight_shape, bias_shape);
    }

protected:
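    /** Fill input, weight and bias tensors with uniformly distributed random values; the ranges
     * follow the reference implementation's test case for QSYMM16. */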
    template <typename InputType, typename BiasType>
    void fill(InputType &&input_tensor, InputType &&weight_tensor, BiasType &&bias_tensor)
    {
        switch(_data_type)
        {
            case DataType::QSYMM16:
            {
                // Value ranges are based on reference implementation's test case.
                constexpr int16_t input_min  = -1000;
                constexpr int16_t input_max  = 1000;
                constexpr int16_t weight_min = 19000;
                constexpr int16_t weight_max = 27000;
                constexpr int32_t bias_min   = -16000000;
                constexpr int32_t bias_max   = -13000000;

                std::uniform_int_distribution<> input_distribution(input_min, input_max);
                std::uniform_int_distribution<> weight_distribution(weight_min, weight_max);
                std::uniform_int_distribution<> bias_distribution(bias_min, bias_max);

                library->fill(input_tensor, input_distribution, 0);
                library->fill(weight_tensor, weight_distribution, 0);
                library->fill(bias_tensor, bias_distribution, 0);
                break;
            }
            default:
                ARM_COMPUTE_ERROR("non-supported data type");
                break;
        }
    }

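    /** Allocate backing memory for each tensor and assert it is resizable before allocation and fixed afterwards. */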
    void allocate_tensors(const std::vector<TensorType *> &tensors)
    {
        for(auto t : tensors)
        {
            ARM_COMPUTE_ASSERT(t->info()->is_resizable());
            t->allocator()->allocate();
            ARM_COMPUTE_ASSERT(!t->info()->is_resizable());
        }
    }

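    /** Run the function under test: create and configure the tensors, allocate and fill them, execute, and return the output. */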
    TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
    {
        TensorType input  = create_tensor<TensorType>(input_shape, _data_type, 1);
        TensorType weight = create_tensor<TensorType>(weight_shape, _data_type, 1, _qinfo);
        TensorType bias   = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
        TensorType output = create_tensor<TensorType>(input_shape, _data_type, 1);

        FunctionType fn;
        fn.configure(&input, &output, &weight, &bias);
        allocate_tensors({ &input, &weight, &bias, &output });
        fill(AccessorType(input), AccessorType(weight), AccessorType(bias));
        fn.run();

        return output;
    }

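    /** Compute the expected output with the reference implementation, filling the tensors with the same values as the target. */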
    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
    {
        // Create reference
        SimpleTensor<T>       input{ input_shape, _data_type, 1 };
        SimpleTensor<T>       weight{ weight_shape, _data_type, 1, _qinfo };
        SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };

        // Fill reference
        fill(input, weight, bias);

        return reference::qlstm_layer_normalization(input, weight, bias);
    }

    TensorType       _target{};
    SimpleTensor<T>  _reference{};
    DataType         _data_type{};
    QuantizationInfo _qinfo{};
};
} // namespace validation
} // namespace test
} // namespace arm_compute

#endif /* ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE */