1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2018-2020 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_SPLIT_DATASET 25*c217d954SCole Faust #define ARM_COMPUTE_TEST_SPLIT_DATASET 26*c217d954SCole Faust 27*c217d954SCole Faust #include "utils/TypePrinter.h" 28*c217d954SCole Faust 29*c217d954SCole Faust #include "arm_compute/core/Types.h" 30*c217d954SCole Faust 31*c217d954SCole Faust namespace arm_compute 32*c217d954SCole Faust { 33*c217d954SCole Faust namespace test 34*c217d954SCole Faust { 35*c217d954SCole Faust namespace datasets 36*c217d954SCole Faust { 37*c217d954SCole Faust class SplitDataset 38*c217d954SCole Faust { 39*c217d954SCole Faust public: 40*c217d954SCole Faust using type = std::tuple<TensorShape, unsigned int, unsigned int>; 41*c217d954SCole Faust 42*c217d954SCole Faust struct iterator 43*c217d954SCole Faust { iteratoriterator44*c217d954SCole Faust iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, 45*c217d954SCole Faust std::vector<unsigned int>::const_iterator axis_values_it, 46*c217d954SCole Faust std::vector<unsigned int>::const_iterator splits_values_it) 47*c217d954SCole Faust : _tensor_shapes_it{ std::move(tensor_shapes_it) }, 48*c217d954SCole Faust _axis_values_it{ std::move(axis_values_it) }, 49*c217d954SCole Faust _splits_values_it{ std::move(splits_values_it) } 50*c217d954SCole Faust { 51*c217d954SCole Faust } 52*c217d954SCole Faust descriptioniterator53*c217d954SCole Faust std::string description() const 54*c217d954SCole Faust { 55*c217d954SCole Faust std::stringstream description; 56*c217d954SCole Faust description << "Shape=" << *_tensor_shapes_it << ":"; 57*c217d954SCole Faust description << "Axis=" << *_axis_values_it << ":"; 58*c217d954SCole Faust description << "Splits=" << *_splits_values_it << ":"; 59*c217d954SCole Faust return description.str(); 60*c217d954SCole Faust } 61*c217d954SCole Faust 62*c217d954SCole Faust SplitDataset::type operator*() const 63*c217d954SCole Faust { 64*c217d954SCole Faust return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_splits_values_it); 65*c217d954SCole Faust } 66*c217d954SCole Faust 67*c217d954SCole Faust iterator &operator++() 68*c217d954SCole Faust { 69*c217d954SCole Faust ++_tensor_shapes_it; 70*c217d954SCole Faust ++_axis_values_it; 71*c217d954SCole Faust ++_splits_values_it; 72*c217d954SCole Faust return *this; 73*c217d954SCole Faust } 74*c217d954SCole Faust 75*c217d954SCole Faust private: 76*c217d954SCole Faust std::vector<TensorShape>::const_iterator _tensor_shapes_it; 77*c217d954SCole Faust std::vector<unsigned int>::const_iterator _axis_values_it; 78*c217d954SCole Faust std::vector<unsigned int>::const_iterator _splits_values_it; 79*c217d954SCole Faust }; 80*c217d954SCole Faust begin()81*c217d954SCole Faust iterator begin() const 82*c217d954SCole Faust { 83*c217d954SCole Faust return iterator(_tensor_shapes.begin(), _axis_values.begin(), _splits_values.begin()); 84*c217d954SCole Faust } 85*c217d954SCole Faust size()86*c217d954SCole Faust int size() const 87*c217d954SCole Faust { 88*c217d954SCole Faust return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _splits_values.size())); 89*c217d954SCole Faust } 90*c217d954SCole Faust add_config(TensorShape shape,unsigned int axis,unsigned int splits)91*c217d954SCole Faust void add_config(TensorShape shape, unsigned int axis, unsigned int splits) 92*c217d954SCole Faust { 93*c217d954SCole Faust _tensor_shapes.emplace_back(std::move(shape)); 94*c217d954SCole Faust _axis_values.emplace_back(axis); 95*c217d954SCole Faust _splits_values.emplace_back(splits); 96*c217d954SCole Faust } 97*c217d954SCole Faust 98*c217d954SCole Faust protected: 99*c217d954SCole Faust SplitDataset() = default; 100*c217d954SCole Faust SplitDataset(SplitDataset &&) = default; 101*c217d954SCole Faust 102*c217d954SCole Faust private: 103*c217d954SCole Faust std::vector<TensorShape> _tensor_shapes{}; 104*c217d954SCole Faust std::vector<unsigned int> _axis_values{}; 105*c217d954SCole Faust std::vector<unsigned int> _splits_values{}; 106*c217d954SCole Faust }; 107*c217d954SCole Faust 108*c217d954SCole Faust class SmallSplitDataset final : public SplitDataset 109*c217d954SCole Faust { 110*c217d954SCole Faust public: SmallSplitDataset()111*c217d954SCole Faust SmallSplitDataset() 112*c217d954SCole Faust { 113*c217d954SCole Faust add_config(TensorShape(128U), 0U, 4U); 114*c217d954SCole Faust add_config(TensorShape(6U, 3U, 4U), 2U, 2U); 115*c217d954SCole Faust add_config(TensorShape(27U, 14U, 2U), 1U, 2U); 116*c217d954SCole Faust add_config(TensorShape(64U, 32U, 4U, 6U), 3U, 3U); 117*c217d954SCole Faust } 118*c217d954SCole Faust }; 119*c217d954SCole Faust 120*c217d954SCole Faust class LargeSplitDataset final : public SplitDataset 121*c217d954SCole Faust { 122*c217d954SCole Faust public: LargeSplitDataset()123*c217d954SCole Faust LargeSplitDataset() 124*c217d954SCole Faust { 125*c217d954SCole Faust add_config(TensorShape(512U), 0U, 8U); 126*c217d954SCole Faust add_config(TensorShape(128U, 64U, 8U), 2U, 2U); 127*c217d954SCole Faust add_config(TensorShape(128U, 64U, 8U, 2U), 1U, 2U); 128*c217d954SCole Faust add_config(TensorShape(128U, 64U, 32U, 4U), 3U, 4U); 129*c217d954SCole Faust } 130*c217d954SCole Faust }; 131*c217d954SCole Faust 132*c217d954SCole Faust class SplitShapesDataset 133*c217d954SCole Faust { 134*c217d954SCole Faust public: 135*c217d954SCole Faust using type = std::tuple<TensorShape, unsigned int, std::vector<TensorShape>>; 136*c217d954SCole Faust 137*c217d954SCole Faust struct iterator 138*c217d954SCole Faust { iteratoriterator139*c217d954SCole Faust iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, 140*c217d954SCole Faust std::vector<unsigned int>::const_iterator axis_values_it, 141*c217d954SCole Faust std::vector<std::vector<TensorShape>>::const_iterator split_shapes_values_it) 142*c217d954SCole Faust : _tensor_shapes_it{ std::move(tensor_shapes_it) }, 143*c217d954SCole Faust _axis_values_it{ std::move(axis_values_it) }, 144*c217d954SCole Faust _split_shapes_values_it{ std::move(split_shapes_values_it) } 145*c217d954SCole Faust { 146*c217d954SCole Faust } 147*c217d954SCole Faust descriptioniterator148*c217d954SCole Faust std::string description() const 149*c217d954SCole Faust { 150*c217d954SCole Faust std::stringstream description; 151*c217d954SCole Faust description << "Shape=" << *_tensor_shapes_it << ":"; 152*c217d954SCole Faust description << "Axis=" << *_axis_values_it << ":"; 153*c217d954SCole Faust description << "Split shapes=" << *_split_shapes_values_it << ":"; 154*c217d954SCole Faust return description.str(); 155*c217d954SCole Faust } 156*c217d954SCole Faust 157*c217d954SCole Faust SplitShapesDataset::type operator*() const 158*c217d954SCole Faust { 159*c217d954SCole Faust return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_split_shapes_values_it); 160*c217d954SCole Faust } 161*c217d954SCole Faust 162*c217d954SCole Faust iterator &operator++() 163*c217d954SCole Faust { 164*c217d954SCole Faust ++_tensor_shapes_it; 165*c217d954SCole Faust ++_axis_values_it; 166*c217d954SCole Faust ++_split_shapes_values_it; 167*c217d954SCole Faust return *this; 168*c217d954SCole Faust } 169*c217d954SCole Faust 170*c217d954SCole Faust private: 171*c217d954SCole Faust std::vector<TensorShape>::const_iterator _tensor_shapes_it; 172*c217d954SCole Faust std::vector<unsigned int>::const_iterator _axis_values_it; 173*c217d954SCole Faust std::vector<std::vector<TensorShape>>::const_iterator _split_shapes_values_it; 174*c217d954SCole Faust }; 175*c217d954SCole Faust begin()176*c217d954SCole Faust iterator begin() const 177*c217d954SCole Faust { 178*c217d954SCole Faust return iterator(_tensor_shapes.begin(), _axis_values.begin(), _split_shapes_values.begin()); 179*c217d954SCole Faust } 180*c217d954SCole Faust size()181*c217d954SCole Faust int size() const 182*c217d954SCole Faust { 183*c217d954SCole Faust return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _split_shapes_values.size())); 184*c217d954SCole Faust } 185*c217d954SCole Faust add_config(TensorShape shape,unsigned int axis,std::vector<TensorShape> split_shapes)186*c217d954SCole Faust void add_config(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes) 187*c217d954SCole Faust { 188*c217d954SCole Faust _tensor_shapes.emplace_back(std::move(shape)); 189*c217d954SCole Faust _axis_values.emplace_back(axis); 190*c217d954SCole Faust _split_shapes_values.emplace_back(split_shapes); 191*c217d954SCole Faust } 192*c217d954SCole Faust 193*c217d954SCole Faust protected: 194*c217d954SCole Faust SplitShapesDataset() = default; 195*c217d954SCole Faust SplitShapesDataset(SplitShapesDataset &&) = default; 196*c217d954SCole Faust 197*c217d954SCole Faust private: 198*c217d954SCole Faust std::vector<TensorShape> _tensor_shapes{}; 199*c217d954SCole Faust std::vector<unsigned int> _axis_values{}; 200*c217d954SCole Faust std::vector<std::vector<TensorShape>> _split_shapes_values{}; 201*c217d954SCole Faust }; 202*c217d954SCole Faust 203*c217d954SCole Faust class SmallSplitShapesDataset final : public SplitShapesDataset 204*c217d954SCole Faust { 205*c217d954SCole Faust public: SmallSplitShapesDataset()206*c217d954SCole Faust SmallSplitShapesDataset() 207*c217d954SCole Faust { 208*c217d954SCole Faust add_config(TensorShape(27U, 3U, 16U, 2U), 2U, std::vector<TensorShape> { TensorShape(27U, 3U, 4U, 2U), 209*c217d954SCole Faust TensorShape(27U, 3U, 4U, 2U), 210*c217d954SCole Faust TensorShape(27U, 3U, 8U, 2U) 211*c217d954SCole Faust }); 212*c217d954SCole Faust } 213*c217d954SCole Faust }; 214*c217d954SCole Faust 215*c217d954SCole Faust } // namespace datasets 216*c217d954SCole Faust } // namespace test 217*c217d954SCole Faust } // namespace arm_compute 218*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_SPLIT_DATASET */ 219