/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_WINOGRAD_OUTPUT_TRANSFORM_DATASET 25*c217d954SCole Faust #define ARM_COMPUTE_TEST_WINOGRAD_OUTPUT_TRANSFORM_DATASET 26*c217d954SCole Faust 27*c217d954SCole Faust #include "utils/TypePrinter.h" 28*c217d954SCole Faust 29*c217d954SCole Faust #include "arm_compute/core/TensorShape.h" 30*c217d954SCole Faust 31*c217d954SCole Faust namespace arm_compute 32*c217d954SCole Faust { 33*c217d954SCole Faust namespace test 34*c217d954SCole Faust { 35*c217d954SCole Faust namespace datasets 36*c217d954SCole Faust { 37*c217d954SCole Faust class WinogradOutputTransformDataset 38*c217d954SCole Faust { 39*c217d954SCole Faust public: 40*c217d954SCole Faust using type = std::tuple<TensorShape, WinogradInfo>; 41*c217d954SCole Faust 42*c217d954SCole Faust struct iterator 43*c217d954SCole Faust { iteratoriterator44*c217d954SCole Faust iterator(std::vector<TensorShape>::const_iterator a_it, 45*c217d954SCole Faust std::vector<WinogradInfo>::const_iterator info_it) 46*c217d954SCole Faust : _a_it{ std::move(a_it) }, 47*c217d954SCole Faust _info_it{ std::move(info_it) } 48*c217d954SCole Faust { 49*c217d954SCole Faust } 50*c217d954SCole Faust descriptioniterator51*c217d954SCole Faust std::string description() const 52*c217d954SCole Faust { 53*c217d954SCole Faust std::stringstream description; 54*c217d954SCole Faust description << "Input=" << *_a_it << ":"; 55*c217d954SCole Faust description << "WinogradInfo=" << *_info_it << ":"; 56*c217d954SCole Faust return description.str(); 57*c217d954SCole Faust } 58*c217d954SCole Faust 59*c217d954SCole Faust WinogradOutputTransformDataset::type operator*() const 60*c217d954SCole Faust { 61*c217d954SCole Faust return std::make_tuple(*_a_it, *_info_it); 62*c217d954SCole Faust } 63*c217d954SCole Faust 64*c217d954SCole Faust iterator &operator++() 65*c217d954SCole Faust { 66*c217d954SCole Faust ++_a_it; 67*c217d954SCole Faust ++_info_it; 68*c217d954SCole Faust 69*c217d954SCole Faust 
return *this; 70*c217d954SCole Faust } 71*c217d954SCole Faust 72*c217d954SCole Faust private: 73*c217d954SCole Faust std::vector<TensorShape>::const_iterator _a_it; 74*c217d954SCole Faust std::vector<WinogradInfo>::const_iterator _info_it; 75*c217d954SCole Faust }; 76*c217d954SCole Faust begin()77*c217d954SCole Faust iterator begin() const 78*c217d954SCole Faust { 79*c217d954SCole Faust return iterator(_a_shapes.begin(), _info.begin()); 80*c217d954SCole Faust } 81*c217d954SCole Faust size()82*c217d954SCole Faust int size() const 83*c217d954SCole Faust { 84*c217d954SCole Faust return std::min(_a_shapes.size(), _info.size()); 85*c217d954SCole Faust } 86*c217d954SCole Faust add_config(TensorShape a,WinogradInfo b)87*c217d954SCole Faust void add_config(TensorShape a, WinogradInfo b) 88*c217d954SCole Faust { 89*c217d954SCole Faust _a_shapes.emplace_back(std::move(a)); 90*c217d954SCole Faust _info.emplace_back(std::move(b)); 91*c217d954SCole Faust } 92*c217d954SCole Faust 93*c217d954SCole Faust protected: 94*c217d954SCole Faust WinogradOutputTransformDataset() = default; 95*c217d954SCole Faust WinogradOutputTransformDataset(WinogradOutputTransformDataset &&) = default; 96*c217d954SCole Faust 97*c217d954SCole Faust private: 98*c217d954SCole Faust std::vector<TensorShape> _a_shapes{}; 99*c217d954SCole Faust std::vector<WinogradInfo> _info{}; 100*c217d954SCole Faust }; 101*c217d954SCole Faust 102*c217d954SCole Faust class SmallWinogradOutputTransformDatasetNCHW final : public WinogradOutputTransformDataset 103*c217d954SCole Faust { 104*c217d954SCole Faust public: SmallWinogradOutputTransformDatasetNCHW()105*c217d954SCole Faust SmallWinogradOutputTransformDatasetNCHW() 106*c217d954SCole Faust { 107*c217d954SCole Faust // (2x2, 3x3) 108*c217d954SCole Faust add_config(TensorShape(13U, 6U, 16U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW)); 109*c217d954SCole Faust add_config(TensorShape(7U, 20U, 16U), 
WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(1U, 442U, 16U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 12U, 16U, 3U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 49U, 16U, 2U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));

        // (4x4, 3x3)
        add_config(TensorShape(13U, 4U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(10U, 9U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 6U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 117U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 4U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 16U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
        // NOTE(review): the next entry uses Size2D(2U, 2U) although it sits under the (4x4, 3x3) label - confirm intent.
        add_config(TensorShape(7U, 12U, 16U, 5U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));

        // (2x1, 3x1)
        add_config(TensorShape(13U, 18U, 4U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 44U, 4U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(1U, 891U, 4U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(53U, 33U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 30U, 4U, 3U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 98U, 4U, 2U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));

        // (1x2, 1x3)
        add_config(TensorShape(13U, 14U, 4U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 50U, 4U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(1U, 901U, 4U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 32U, 4U, 3U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 98U, 4U, 2U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));

        // (4x1, 3x1)
        add_config(TensorShape(13U, 12U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 22U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(1U, 462U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(53U, 33U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 20U, 6U, 3U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 56U, 6U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));

        // (1x4, 1x3)
        add_config(TensorShape(13U, 7U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 30U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(1U, 477U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 16U, 6U, 3U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 56U, 6U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));

        // (4x4, 5x5)
        add_config(TensorShape(13U, 1U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 4U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(5U, 104U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 2U, 64U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 9U, 64U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 2U, 64U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));

        // (4x1, 5x1)
        add_config(TensorShape(13U, 6U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 22U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(5U, 462U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(53U, 33U), PadStrideInfo(1, 1, 2, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 10U, 8U, 3U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 42U, 8U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 20U, 8U, 5U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 2, 0), DataLayout::NCHW));

        // (1x4, 1x5)
        add_config(TensorShape(13U, 7U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(7U, 20U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(5U, 477U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 2), DataLayout::NCHW));
        add_config(TensorShape(7U, 16U, 8U, 3U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(24U, 42U, 8U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(14U, 14U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(7U, 24U, 8U, 5U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 2), DataLayout::NCHW));
    }
};

/** Dataset of NHWC configurations registered at construction time.
 *
 * Not declared final: SmallWinogradOutputTransformDatasetNHWC_F32 derives from it.
 */
class SmallWinogradOutputTransformDatasetNHWC_F16 : public WinogradOutputTransformDataset
{
public:
    SmallWinogradOutputTransformDatasetNHWC_F16()
    {
        // (4x1, 3x1)
        add_config(TensorShape(1U, 12U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
        add_config(TensorShape(7U, 22U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
        add_config(TensorShape(1U, 462U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(53U, 33U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC));
        add_config(TensorShape(16U, 20U, 6U, 3U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
        add_config(TensorShape(24U, 56U, 6U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(14U, 14U),
PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 187*c217d954SCole Faust 188*c217d954SCole Faust // (1x4, 1x3) 189*c217d954SCole Faust add_config(TensorShape(1U, 7U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 190*c217d954SCole Faust add_config(TensorShape(7U, 30U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 191*c217d954SCole Faust add_config(TensorShape(1U, 477U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 192*c217d954SCole Faust add_config(TensorShape(16U, 16U, 6U, 3U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 193*c217d954SCole Faust add_config(TensorShape(24U, 56U, 6U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 194*c217d954SCole Faust 195*c217d954SCole Faust // (4x4, 3x3) 196*c217d954SCole Faust add_config(TensorShape(1U, 4U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(10U, 9U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 197*c217d954SCole Faust add_config(TensorShape(13U, 6U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 198*c217d954SCole Faust add_config(TensorShape(7U, 117U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 199*c217d954SCole Faust add_config(TensorShape(16U, 4U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 200*c217d954SCole Faust add_config(TensorShape(24U, 16U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC)); 201*c217d954SCole Faust 202*c217d954SCole Faust // (4x4, 5x5) 203*c217d954SCole Faust 
add_config(TensorShape(1U, 1U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 204*c217d954SCole Faust add_config(TensorShape(7U, 4U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 205*c217d954SCole Faust add_config(TensorShape(5U, 104U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 206*c217d954SCole Faust add_config(TensorShape(7U, 2U, 64U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 207*c217d954SCole Faust add_config(TensorShape(16U, 9U, 64U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC)); 208*c217d954SCole Faust add_config(TensorShape(7U, 2U, 64U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 209*c217d954SCole Faust 210*c217d954SCole Faust // (4x1, 5x1) 211*c217d954SCole Faust add_config(TensorShape(1U, 6U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 212*c217d954SCole Faust add_config(TensorShape(7U, 22U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 213*c217d954SCole Faust add_config(TensorShape(5U, 462U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(53U, 33U), PadStrideInfo(1, 1, 2, 0), DataLayout::NHWC)); 214*c217d954SCole Faust add_config(TensorShape(7U, 10U, 8U, 3U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 215*c217d954SCole Faust add_config(TensorShape(16U, 42U, 8U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 216*c217d954SCole Faust add_config(TensorShape(7U, 20U, 8U, 
5U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 2, 0), DataLayout::NHWC)); 217*c217d954SCole Faust 218*c217d954SCole Faust // (1x4, 1x5) 219*c217d954SCole Faust add_config(TensorShape(1U, 7U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(7U, 6U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 220*c217d954SCole Faust add_config(TensorShape(7U, 20U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 221*c217d954SCole Faust add_config(TensorShape(5U, 477U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 2), DataLayout::NHWC)); 222*c217d954SCole Faust add_config(TensorShape(7U, 16U, 8U, 3U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 223*c217d954SCole Faust add_config(TensorShape(16U, 42U, 8U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(14U, 14U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 224*c217d954SCole Faust add_config(TensorShape(7U, 24U, 8U, 5U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 2), DataLayout::NHWC)); 225*c217d954SCole Faust 226*c217d954SCole Faust // (2x1, 7x1) 227*c217d954SCole Faust add_config(TensorShape(1U, 18U, 8U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(9U, 9U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 228*c217d954SCole Faust add_config(TensorShape(7U, 22U, 8U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 229*c217d954SCole Faust add_config(TensorShape(5U, 858U, 8U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(53U, 33U), PadStrideInfo(1, 1, 2, 0), DataLayout::NHWC)); 230*c217d954SCole Faust add_config(TensorShape(7U, 10U, 8U, 3U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 231*c217d954SCole Faust 
add_config(TensorShape(16U, 70U, 8U, 2U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 232*c217d954SCole Faust add_config(TensorShape(7U, 30U, 8U, 5U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(8U, 10U), PadStrideInfo(1, 1, 2, 0), DataLayout::NHWC)); 233*c217d954SCole Faust 234*c217d954SCole Faust // (1x2, 1x7) 235*c217d954SCole Faust add_config(TensorShape(1U, 18U, 8U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(9U, 9U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 236*c217d954SCole Faust add_config(TensorShape(7U, 30U, 8U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 237*c217d954SCole Faust add_config(TensorShape(5U, 848U, 8U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 2), DataLayout::NHWC)); 238*c217d954SCole Faust add_config(TensorShape(7U, 16U, 8U, 3U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 239*c217d954SCole Faust add_config(TensorShape(16U, 70U, 8U, 2U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(14U, 14U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 240*c217d954SCole Faust add_config(TensorShape(7U, 32U, 8U, 5U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 2), DataLayout::NHWC)); 241*c217d954SCole Faust } 242*c217d954SCole Faust }; 243*c217d954SCole Faust 244*c217d954SCole Faust class SmallWinogradOutputTransformDatasetNHWC_F32 : public SmallWinogradOutputTransformDatasetNHWC_F16 245*c217d954SCole Faust { 246*c217d954SCole Faust public: SmallWinogradOutputTransformDatasetNHWC_F32()247*c217d954SCole Faust SmallWinogradOutputTransformDatasetNHWC_F32() 248*c217d954SCole Faust : SmallWinogradOutputTransformDatasetNHWC_F16() 249*c217d954SCole Faust { 250*c217d954SCole Faust // (2x2, 7x7) 251*c217d954SCole Faust add_config(TensorShape(1U, 4U, 64U), 
WinogradInfo(Size2D(2U, 2U), Size2D(7U, 7U), Size2D(9U, 9U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 252*c217d954SCole Faust add_config(TensorShape(7U, 6U, 64U), WinogradInfo(Size2D(2U, 2U), Size2D(7U, 7U), Size2D(10U, 11U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 253*c217d954SCole Faust add_config(TensorShape(5U, 360U, 64U), WinogradInfo(Size2D(2U, 2U), Size2D(7U, 7U), Size2D(53U, 33U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 254*c217d954SCole Faust add_config(TensorShape(7U, 2U, 64U, 3U), WinogradInfo(Size2D(2U, 2U), Size2D(7U, 7U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 255*c217d954SCole Faust add_config(TensorShape(16U, 25U, 64U, 2U), WinogradInfo(Size2D(2U, 2U), Size2D(7U, 7U), Size2D(14U, 14U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC)); 256*c217d954SCole Faust add_config(TensorShape(7U, 2U, 64U, 5U), WinogradInfo(Size2D(2U, 2U), Size2D(7U, 7U), Size2D(8U, 10U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 257*c217d954SCole Faust } 258*c217d954SCole Faust }; 259*c217d954SCole Faust 260*c217d954SCole Faust class LargeWinogradOutputTransformDatasetNCHW : public WinogradOutputTransformDataset 261*c217d954SCole Faust { 262*c217d954SCole Faust public: LargeWinogradOutputTransformDatasetNCHW()263*c217d954SCole Faust LargeWinogradOutputTransformDatasetNCHW() 264*c217d954SCole Faust { 265*c217d954SCole Faust // (2x2, 3x3) 266*c217d954SCole Faust add_config(TensorShape(64U, 12544U, 16U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW)); 267*c217d954SCole Faust add_config(TensorShape(32U, 3080U, 16U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW)); 268*c217d954SCole Faust add_config(TensorShape(13U, 756U, 16U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW)); 269*c217d954SCole Faust add_config(TensorShape(64U, 12544U, 16U, 3U), 
WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 3080U, 16U, 2U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 756U, 16U, 5U), WinogradInfo(Size2D(2U, 2U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));

        // (4x4, 3x3)
        add_config(TensorShape(64U, 3136U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 784U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 196U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(64U, 3136U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 784U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 196U, 36U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));

        // (2x1, 3x1)
        add_config(TensorShape(64U, 24976U, 4U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(224U, 223U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 6160U, 4U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(112U, 110U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 1568U, 4U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(64U, 24753U, 4U, 3U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 6050U, 4U, 2U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 1512U, 4U, 5U), WinogradInfo(Size2D(2U, 1U), Size2D(3U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));

        // (1x2, 1x3)
        add_config(TensorShape(64U, 25088U, 4U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 6160U, 4U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(13U, 1568U, 4U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(64U, 24864U, 4U, 3U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 6048U, 4U, 2U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 1512U, 4U, 5U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));

        // (4x1, 3x1)
        add_config(TensorShape(64U, 12488U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(224U, 223U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 3080U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(112U, 110U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(64U, 12488U, 6U, 3U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 3080U, 6U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 6U, 5U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));

        // (1x4, 1x3)
        add_config(TensorShape(64U, 12544U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 3136U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(64U, 12544U, 6U, 3U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 3024U, 6U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 6U, 5U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));

        // (4x4, 5x5)
        add_config(TensorShape(32U, 756U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 182U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 756U, 64U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 182U, 64U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));

        // (4x1, 5x1)
        add_config(TensorShape(32U, 3136U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(112U, 112U), PadStrideInfo(1, 1, 2, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));
        add_config(TensorShape(32U, 3024U, 8U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 8U, 5U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NCHW));

        // (1x4, 1x5)
        add_config(TensorShape(32U, 3136U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 2), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
        add_config(TensorShape(32U, 3024U, 8U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 0), DataLayout::NCHW));
        add_config(TensorShape(13U, 784U, 8U, 5U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NCHW));
    }
};

/** Dataset of NHWC configurations with larger shapes, registered at construction time. */
class LargeWinogradOutputTransformDatasetNHWC_F16 : public WinogradOutputTransformDataset
{
public:
    LargeWinogradOutputTransformDatasetNHWC_F16()
    {
        // (4x1, 3x1)
        add_config(TensorShape(64U, 12488U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(224U, 223U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC));
        add_config(TensorShape(32U, 3080U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(112U, 110U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC));
        add_config(TensorShape(13U, 784U, 6U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC));
        add_config(TensorShape(64U, 12488U, 6U, 3U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
        add_config(TensorShape(32U, 3080U, 6U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));
        add_config(TensorShape(13U, 784U, 6U, 5U), WinogradInfo(Size2D(4U, 1U), Size2D(3U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC));

        // (1x4, 1x3)
347*c217d954SCole Faust add_config(TensorShape(64U, 12544U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 348*c217d954SCole Faust add_config(TensorShape(32U, 3136U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 349*c217d954SCole Faust add_config(TensorShape(13U, 784U, 6U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 350*c217d954SCole Faust add_config(TensorShape(64U, 12544U, 6U, 3U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(224U, 223U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 351*c217d954SCole Faust add_config(TensorShape(32U, 3024U, 6U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(112U, 110U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 352*c217d954SCole Faust add_config(TensorShape(13U, 784U, 6U, 5U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 353*c217d954SCole Faust 354*c217d954SCole Faust // (4x4, 3x3) 355*c217d954SCole Faust add_config(TensorShape(64U, 3136U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC)); 356*c217d954SCole Faust add_config(TensorShape(32U, 784U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 357*c217d954SCole Faust add_config(TensorShape(13U, 196U, 36U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 358*c217d954SCole Faust add_config(TensorShape(64U, 3136U, 36U, 3U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(224U, 224U), PadStrideInfo(1, 1, 1, 1), DataLayout::NHWC)); 359*c217d954SCole Faust add_config(TensorShape(32U, 784U, 36U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), 
DataLayout::NHWC)); 360*c217d954SCole Faust add_config(TensorShape(13U, 196U, 36U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(3U, 3U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 361*c217d954SCole Faust 362*c217d954SCole Faust // (4x4, 5x5) 363*c217d954SCole Faust add_config(TensorShape(32U, 756U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 364*c217d954SCole Faust add_config(TensorShape(13U, 182U, 64U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 365*c217d954SCole Faust add_config(TensorShape(32U, 756U, 64U, 2U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 366*c217d954SCole Faust add_config(TensorShape(13U, 182U, 64U, 5U), WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 367*c217d954SCole Faust 368*c217d954SCole Faust // (4x1, 5x1) 369*c217d954SCole Faust add_config(TensorShape(32U, 3136U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(112U, 112U), PadStrideInfo(1, 1, 2, 0), DataLayout::NHWC)); 370*c217d954SCole Faust add_config(TensorShape(13U, 784U, 8U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 371*c217d954SCole Faust add_config(TensorShape(32U, 3024U, 8U, 2U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 372*c217d954SCole Faust add_config(TensorShape(13U, 784U, 8U, 5U), WinogradInfo(Size2D(4U, 1U), Size2D(5U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 373*c217d954SCole Faust 374*c217d954SCole Faust // (1x4, 1x5) 375*c217d954SCole Faust add_config(TensorShape(32U, 3136U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 2), DataLayout::NHWC)); 376*c217d954SCole Faust 
add_config(TensorShape(13U, 784U, 8U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 377*c217d954SCole Faust add_config(TensorShape(32U, 3024U, 8U, 2U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 378*c217d954SCole Faust add_config(TensorShape(13U, 784U, 8U, 5U), WinogradInfo(Size2D(1U, 4U), Size2D(1U, 5U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 379*c217d954SCole Faust 380*c217d954SCole Faust // (2x1, 7x1) 381*c217d954SCole Faust add_config(TensorShape(32U, 6160U, 8U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(112U, 112U), PadStrideInfo(1, 1, 2, 0), DataLayout::NHWC)); 382*c217d954SCole Faust add_config(TensorShape(13U, 1456U, 8U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 383*c217d954SCole Faust add_config(TensorShape(32U, 5936U, 8U, 2U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 384*c217d954SCole Faust add_config(TensorShape(13U, 1456U, 8U, 5U), WinogradInfo(Size2D(2U, 1U), Size2D(7U, 1U), Size2D(56U, 56U), PadStrideInfo(1, 1, 1, 0), DataLayout::NHWC)); 385*c217d954SCole Faust 386*c217d954SCole Faust // (1x2, 1x7) 387*c217d954SCole Faust add_config(TensorShape(32U, 6160U, 8U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 2), DataLayout::NHWC)); 388*c217d954SCole Faust add_config(TensorShape(13U, 1456U, 8U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(56U, 56U), PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 389*c217d954SCole Faust add_config(TensorShape(32U, 5936U, 8U, 2U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(112U, 112U), PadStrideInfo(1, 1, 0, 0), DataLayout::NHWC)); 390*c217d954SCole Faust add_config(TensorShape(13U, 1456U, 8U, 5U), WinogradInfo(Size2D(1U, 2U), Size2D(1U, 7U), Size2D(56U, 56U), 
PadStrideInfo(1, 1, 0, 1), DataLayout::NHWC)); 391*c217d954SCole Faust } 392*c217d954SCole Faust }; 393*c217d954SCole Faust 394*c217d954SCole Faust class LargeWinogradOutputTransformDatasetNHWC_F32 : public LargeWinogradOutputTransformDatasetNHWC_F16 395*c217d954SCole Faust { 396*c217d954SCole Faust public: LargeWinogradOutputTransformDatasetNHWC_F32()397*c217d954SCole Faust LargeWinogradOutputTransformDatasetNHWC_F32() 398*c217d954SCole Faust : LargeWinogradOutputTransformDatasetNHWC_F16() 399*c217d954SCole Faust { 400*c217d954SCole Faust } 401*c217d954SCole Faust }; 402*c217d954SCole Faust } // namespace datasets 403*c217d954SCole Faust } // namespace test 404*c217d954SCole Faust } // namespace arm_compute 405*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_WINOGRAD_OUTPUT_TRANSFORM_DATASET */