// xref: /aosp_15_r20/external/android-nn-driver/test/1.0/FullyConnectedReshape.cpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "../DriverTestHelpers.hpp"
7 
8 DOCTEST_TEST_SUITE("FullyConnectedReshapeTests")
9 {
10 DOCTEST_TEST_CASE("TestFlattenFullyConnectedInput")
11 {
12     using armnn::TensorShape;
13 
14     // Pass through 2d input
15     DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({2,2048}),
16                                              TensorShape({512, 2048})) == TensorShape({2, 2048}));
17 
18     // Trivial flattening of batched channels
19     DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}),
20                                              TensorShape({512, 2048})) == TensorShape({97, 2048}));
21 
22     // Flatten single batch of rows
23     DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}),
24                                              TensorShape({512, 2048})) == TensorShape({97, 2048}));
25 
26     // Flatten single batch of columns
27     DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}),
28                                              TensorShape({512, 2048})) == TensorShape({97, 2048}));
29 
30     // Move batches into input dimension
31     DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({50,1,1,10}),
32                                              TensorShape({512, 20})) == TensorShape({25, 20}));
33 
34     // Flatten single batch of 3D data (e.g. convolution output)
35     DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,16,16,10}),
36                                              TensorShape({512, 2560})) == TensorShape({1, 2560}));
37 }
38 
39 }
40