1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "../DriverTestHelpers.hpp" 7 8 DOCTEST_TEST_SUITE("FullyConnectedReshapeTests") 9 { 10 DOCTEST_TEST_CASE("TestFlattenFullyConnectedInput") 11 { 12 using armnn::TensorShape; 13 14 // Pass through 2d input 15 DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({2,2048}), 16 TensorShape({512, 2048})) == TensorShape({2, 2048})); 17 18 // Trivial flattening of batched channels 19 DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), 20 TensorShape({512, 2048})) == TensorShape({97, 2048})); 21 22 // Flatten single batch of rows 23 DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), 24 TensorShape({512, 2048})) == TensorShape({97, 2048})); 25 26 // Flatten single batch of columns 27 DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), 28 TensorShape({512, 2048})) == TensorShape({97, 2048})); 29 30 // Move batches into input dimension 31 DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), 32 TensorShape({512, 20})) == TensorShape({25, 20})); 33 34 // Flatten single batch of 3D data (e.g. convolution output) 35 DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), 36 TensorShape({512, 2560})) == TensorShape({1, 2560})); 37 } 38 39 } 40