1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2018 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET 25*c217d954SCole Faust #define ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET 26*c217d954SCole Faust 27*c217d954SCole Faust #include "utils/TypePrinter.h" 28*c217d954SCole Faust 29*c217d954SCole Faust #include "arm_compute/core/Types.h" 30*c217d954SCole Faust 31*c217d954SCole Faust namespace arm_compute 32*c217d954SCole Faust { 33*c217d954SCole Faust namespace test 34*c217d954SCole Faust { 35*c217d954SCole Faust namespace datasets 36*c217d954SCole Faust { 37*c217d954SCole Faust class SliceDataset 38*c217d954SCole Faust { 39*c217d954SCole Faust public: 40*c217d954SCole Faust using type = std::tuple<TensorShape, Coordinates, Coordinates>; 41*c217d954SCole Faust 42*c217d954SCole Faust struct iterator 43*c217d954SCole Faust { iteratoriterator44*c217d954SCole Faust iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, 45*c217d954SCole Faust std::vector<Coordinates>::const_iterator starts_values_it, 46*c217d954SCole Faust std::vector<Coordinates>::const_iterator ends_values_it) 47*c217d954SCole Faust : _tensor_shapes_it{ std::move(tensor_shapes_it) }, 48*c217d954SCole Faust _starts_values_it{ std::move(starts_values_it) }, 49*c217d954SCole Faust _ends_values_it{ std::move(ends_values_it) } 50*c217d954SCole Faust { 51*c217d954SCole Faust } 52*c217d954SCole Faust descriptioniterator53*c217d954SCole Faust std::string description() const 54*c217d954SCole Faust { 55*c217d954SCole Faust std::stringstream description; 56*c217d954SCole Faust description << "Shape=" << *_tensor_shapes_it << ":"; 57*c217d954SCole Faust description << "Starts=" << *_starts_values_it << ":"; 58*c217d954SCole Faust description << "Ends=" << *_ends_values_it << ":"; 59*c217d954SCole Faust return description.str(); 60*c217d954SCole Faust } 61*c217d954SCole Faust 62*c217d954SCole Faust SliceDataset::type operator*() const 63*c217d954SCole Faust { 64*c217d954SCole Faust return std::make_tuple(*_tensor_shapes_it, *_starts_values_it, *_ends_values_it); 65*c217d954SCole Faust } 66*c217d954SCole Faust 67*c217d954SCole Faust iterator &operator++() 68*c217d954SCole Faust { 69*c217d954SCole Faust ++_tensor_shapes_it; 70*c217d954SCole Faust ++_starts_values_it; 71*c217d954SCole Faust ++_ends_values_it; 72*c217d954SCole Faust return *this; 73*c217d954SCole Faust } 74*c217d954SCole Faust 75*c217d954SCole Faust private: 76*c217d954SCole Faust std::vector<TensorShape>::const_iterator _tensor_shapes_it; 77*c217d954SCole Faust std::vector<Coordinates>::const_iterator _starts_values_it; 78*c217d954SCole Faust std::vector<Coordinates>::const_iterator _ends_values_it; 79*c217d954SCole Faust }; 80*c217d954SCole Faust begin()81*c217d954SCole Faust iterator begin() const 82*c217d954SCole Faust { 83*c217d954SCole Faust return iterator(_tensor_shapes.begin(), _starts_values.begin(), _ends_values.begin()); 84*c217d954SCole Faust } 85*c217d954SCole Faust size()86*c217d954SCole Faust int size() const 87*c217d954SCole Faust { 88*c217d954SCole Faust return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), _ends_values.size())); 89*c217d954SCole Faust } 90*c217d954SCole Faust add_config(TensorShape shape,Coordinates starts,Coordinates ends)91*c217d954SCole Faust void add_config(TensorShape shape, Coordinates starts, Coordinates ends) 92*c217d954SCole Faust { 93*c217d954SCole Faust _tensor_shapes.emplace_back(std::move(shape)); 94*c217d954SCole Faust _starts_values.emplace_back(std::move(starts)); 95*c217d954SCole Faust _ends_values.emplace_back(std::move(ends)); 96*c217d954SCole Faust } 97*c217d954SCole Faust 98*c217d954SCole Faust protected: 99*c217d954SCole Faust SliceDataset() = default; 100*c217d954SCole Faust SliceDataset(SliceDataset &&) = default; 101*c217d954SCole Faust 102*c217d954SCole Faust private: 103*c217d954SCole Faust std::vector<TensorShape> _tensor_shapes{}; 104*c217d954SCole Faust std::vector<Coordinates> _starts_values{}; 105*c217d954SCole Faust std::vector<Coordinates> _ends_values{}; 106*c217d954SCole Faust }; 107*c217d954SCole Faust 108*c217d954SCole Faust class StridedSliceDataset 109*c217d954SCole Faust { 110*c217d954SCole Faust public: 111*c217d954SCole Faust using type = std::tuple<TensorShape, Coordinates, Coordinates, BiStrides, int32_t, int32_t, int32_t>; 112*c217d954SCole Faust 113*c217d954SCole Faust struct iterator 114*c217d954SCole Faust { iteratoriterator115*c217d954SCole Faust iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, 116*c217d954SCole Faust std::vector<Coordinates>::const_iterator starts_values_it, 117*c217d954SCole Faust std::vector<Coordinates>::const_iterator ends_values_it, 118*c217d954SCole Faust std::vector<BiStrides>::const_iterator strides_values_it, 119*c217d954SCole Faust std::vector<int32_t>::const_iterator begin_mask_values_it, 120*c217d954SCole Faust std::vector<int32_t>::const_iterator end_mask_values_it, 121*c217d954SCole Faust std::vector<int32_t>::const_iterator shrink_mask_values_it) 122*c217d954SCole Faust : _tensor_shapes_it{ std::move(tensor_shapes_it) }, 123*c217d954SCole Faust _starts_values_it{ std::move(starts_values_it) }, 124*c217d954SCole Faust _ends_values_it{ std::move(ends_values_it) }, 125*c217d954SCole Faust _strides_values_it{ std::move(strides_values_it) }, 126*c217d954SCole Faust _begin_mask_values_it{ std::move(begin_mask_values_it) }, 127*c217d954SCole Faust _end_mask_values_it{ std::move(end_mask_values_it) }, 128*c217d954SCole Faust _shrink_mask_values_it{ std::move(shrink_mask_values_it) } 129*c217d954SCole Faust { 130*c217d954SCole Faust } 131*c217d954SCole Faust descriptioniterator132*c217d954SCole Faust std::string description() const 133*c217d954SCole Faust { 134*c217d954SCole Faust std::stringstream description; 135*c217d954SCole Faust description << "Shape=" << *_tensor_shapes_it << ":"; 136*c217d954SCole Faust description << "Starts=" << *_starts_values_it << ":"; 137*c217d954SCole Faust description << "Ends=" << *_ends_values_it << ":"; 138*c217d954SCole Faust description << "Strides=" << *_strides_values_it << ":"; 139*c217d954SCole Faust description << "BeginMask=" << *_begin_mask_values_it << ":"; 140*c217d954SCole Faust description << "EndMask=" << *_end_mask_values_it << ":"; 141*c217d954SCole Faust description << "ShrinkMask=" << *_shrink_mask_values_it << ":"; 142*c217d954SCole Faust return description.str(); 143*c217d954SCole Faust } 144*c217d954SCole Faust 145*c217d954SCole Faust StridedSliceDataset::type operator*() const 146*c217d954SCole Faust { 147*c217d954SCole Faust return std::make_tuple(*_tensor_shapes_it, 148*c217d954SCole Faust *_starts_values_it, *_ends_values_it, *_strides_values_it, 149*c217d954SCole Faust *_begin_mask_values_it, *_end_mask_values_it, *_shrink_mask_values_it); 150*c217d954SCole Faust } 151*c217d954SCole Faust 152*c217d954SCole Faust iterator &operator++() 153*c217d954SCole Faust { 154*c217d954SCole Faust ++_tensor_shapes_it; 155*c217d954SCole Faust ++_starts_values_it; 156*c217d954SCole Faust ++_ends_values_it; 157*c217d954SCole Faust ++_strides_values_it; 158*c217d954SCole Faust ++_begin_mask_values_it; 159*c217d954SCole Faust ++_end_mask_values_it; 160*c217d954SCole Faust ++_shrink_mask_values_it; 161*c217d954SCole Faust 162*c217d954SCole Faust return *this; 163*c217d954SCole Faust } 164*c217d954SCole Faust 165*c217d954SCole Faust private: 166*c217d954SCole Faust std::vector<TensorShape>::const_iterator _tensor_shapes_it; 167*c217d954SCole Faust std::vector<Coordinates>::const_iterator _starts_values_it; 168*c217d954SCole Faust std::vector<Coordinates>::const_iterator _ends_values_it; 169*c217d954SCole Faust std::vector<BiStrides>::const_iterator _strides_values_it; 170*c217d954SCole Faust std::vector<int32_t>::const_iterator _begin_mask_values_it; 171*c217d954SCole Faust std::vector<int32_t>::const_iterator _end_mask_values_it; 172*c217d954SCole Faust std::vector<int32_t>::const_iterator _shrink_mask_values_it; 173*c217d954SCole Faust }; 174*c217d954SCole Faust begin()175*c217d954SCole Faust iterator begin() const 176*c217d954SCole Faust { 177*c217d954SCole Faust return iterator(_tensor_shapes.begin(), 178*c217d954SCole Faust _starts_values.begin(), _ends_values.begin(), _strides_values.begin(), 179*c217d954SCole Faust _begin_mask_values.begin(), _end_mask_values.begin(), _shrink_mask_values.begin()); 180*c217d954SCole Faust } 181*c217d954SCole Faust size()182*c217d954SCole Faust int size() const 183*c217d954SCole Faust { 184*c217d954SCole Faust return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), std::min(_ends_values.size(), _strides_values.size()))); 185*c217d954SCole Faust } 186*c217d954SCole Faust 187*c217d954SCole Faust void add_config(TensorShape shape, 188*c217d954SCole Faust Coordinates starts, Coordinates ends, BiStrides strides, 189*c217d954SCole Faust int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_mask = 0) 190*c217d954SCole Faust { 191*c217d954SCole Faust _tensor_shapes.emplace_back(std::move(shape)); 192*c217d954SCole Faust _starts_values.emplace_back(std::move(starts)); 193*c217d954SCole Faust _ends_values.emplace_back(std::move(ends)); 194*c217d954SCole Faust _strides_values.emplace_back(std::move(strides)); 195*c217d954SCole Faust _begin_mask_values.emplace_back(std::move(begin_mask)); 196*c217d954SCole Faust _end_mask_values.emplace_back(std::move(end_mask)); 197*c217d954SCole Faust _shrink_mask_values.emplace_back(std::move(shrink_mask)); 198*c217d954SCole Faust } 199*c217d954SCole Faust 200*c217d954SCole Faust protected: 201*c217d954SCole Faust StridedSliceDataset() = default; 202*c217d954SCole Faust StridedSliceDataset(StridedSliceDataset &&) = default; 203*c217d954SCole Faust 204*c217d954SCole Faust private: 205*c217d954SCole Faust std::vector<TensorShape> _tensor_shapes{}; 206*c217d954SCole Faust std::vector<Coordinates> _starts_values{}; 207*c217d954SCole Faust std::vector<Coordinates> _ends_values{}; 208*c217d954SCole Faust std::vector<BiStrides> _strides_values{}; 209*c217d954SCole Faust std::vector<int32_t> _begin_mask_values{}; 210*c217d954SCole Faust std::vector<int32_t> _end_mask_values{}; 211*c217d954SCole Faust std::vector<int32_t> _shrink_mask_values{}; 212*c217d954SCole Faust }; 213*c217d954SCole Faust 214*c217d954SCole Faust class SmallSliceDataset final : public SliceDataset 215*c217d954SCole Faust { 216*c217d954SCole Faust public: SmallSliceDataset()217*c217d954SCole Faust SmallSliceDataset() 218*c217d954SCole Faust { 219*c217d954SCole Faust // 1D 220*c217d954SCole Faust add_config(TensorShape(15U), Coordinates(4), Coordinates(9)); 221*c217d954SCole Faust add_config(TensorShape(15U), Coordinates(0), Coordinates(-1)); 222*c217d954SCole Faust // 2D 223*c217d954SCole Faust add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1)); 224*c217d954SCole Faust add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1)); 225*c217d954SCole Faust // 3D 226*c217d954SCole Faust add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4)); 227*c217d954SCole Faust add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4)); 228*c217d954SCole Faust // 4D 229*c217d954SCole Faust add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5)); 230*c217d954SCole Faust } 231*c217d954SCole Faust }; 232*c217d954SCole Faust 233*c217d954SCole Faust class LargeSliceDataset final : public SliceDataset 234*c217d954SCole Faust { 235*c217d954SCole Faust public: LargeSliceDataset()236*c217d954SCole Faust LargeSliceDataset() 237*c217d954SCole Faust { 238*c217d954SCole Faust // 1D 239*c217d954SCole Faust add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100)); 240*c217d954SCole Faust // 2D 241*c217d954SCole Faust add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -1)); 242*c217d954SCole Faust // 3D 243*c217d954SCole Faust add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, 2), Coordinates(368, -1, 4)); 244*c217d954SCole Faust // 4D 245*c217d954SCole Faust add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, 17, 5)); 246*c217d954SCole Faust } 247*c217d954SCole Faust }; 248*c217d954SCole Faust 249*c217d954SCole Faust class SmallStridedSliceDataset final : public StridedSliceDataset 250*c217d954SCole Faust { 251*c217d954SCole Faust public: SmallStridedSliceDataset()252*c217d954SCole Faust SmallStridedSliceDataset() 253*c217d954SCole Faust { 254*c217d954SCole Faust // 1D 255*c217d954SCole Faust add_config(TensorShape(15U), Coordinates(0), Coordinates(5), BiStrides(2)); 256*c217d954SCole Faust add_config(TensorShape(15U), Coordinates(-1), Coordinates(-8), BiStrides(-2)); 257*c217d954SCole Faust // 2D 258*c217d954SCole Faust add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1), BiStrides(2, 1)); 259*c217d954SCole Faust add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1), BiStrides(2, 1), 1); 260*c217d954SCole Faust // 3D 261*c217d954SCole Faust add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2)); 262*c217d954SCole Faust add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2), 0, 1); 263*c217d954SCole Faust // 4D 264*c217d954SCole Faust add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5), BiStrides(2, 1, 2, 3)); 265*c217d954SCole Faust 266*c217d954SCole Faust // Shrink axis 267*c217d954SCole Faust add_config(TensorShape(1U, 3U, 2U, 3U), Coordinates(0, 1, 0, 0), Coordinates(1, 1, 1, 1), BiStrides(1, 1, 1, 1), 0, 15, 6); 268*c217d954SCole Faust add_config(TensorShape(3U, 2U), Coordinates(0, 0), Coordinates(3U, 1U), BiStrides(1, 1), 0, 0, 2); 269*c217d954SCole Faust add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 0, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 6, 1); 270*c217d954SCole Faust add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 1, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 5, 3); 271*c217d954SCole Faust } 272*c217d954SCole Faust }; 273*c217d954SCole Faust 274*c217d954SCole Faust class LargeStridedSliceDataset final : public StridedSliceDataset 275*c217d954SCole Faust { 276*c217d954SCole Faust public: LargeStridedSliceDataset()277*c217d954SCole Faust LargeStridedSliceDataset() 278*c217d954SCole Faust { 279*c217d954SCole Faust // 1D 280*c217d954SCole Faust add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100), BiStrides(20)); 281*c217d954SCole Faust // 2D 282*c217d954SCole Faust add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -30), BiStrides(10, 7)); 283*c217d954SCole Faust // 3D 284*c217d954SCole Faust add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, -1), Coordinates(368, -30, -5), BiStrides(14, 7, -2)); 285*c217d954SCole Faust // 4D 286*c217d954SCole Faust add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, -30, 5), BiStrides(20, 7, 2), 1, 1); 287*c217d954SCole Faust } 288*c217d954SCole Faust }; 289*c217d954SCole Faust } // namespace datasets 290*c217d954SCole Faust } // namespace test 291*c217d954SCole Faust } // namespace arm_compute 292*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET */ 293