// xref: /aosp_15_r20/external/eigen/unsupported/test/cxx11_tensor_trace.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gagan Goel <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 using Eigen::array;
16 
17 template <int DataLayout>
test_0D_trace()18 static void test_0D_trace() {
19   Tensor<float, 0, DataLayout> tensor;
20   tensor.setRandom();
21   array<ptrdiff_t, 0> dims;
22   Tensor<float, 0, DataLayout> result = tensor.trace(dims);
23   VERIFY_IS_EQUAL(result(), tensor());
24 }
25 
26 
27 template <int DataLayout>
test_all_dimensions_trace()28 static void test_all_dimensions_trace() {
29   Tensor<float, 3, DataLayout> tensor1(5, 5, 5);
30   tensor1.setRandom();
31   Tensor<float, 0, DataLayout> result1 = tensor1.trace();
32   VERIFY_IS_EQUAL(result1.rank(), 0);
33   float sum = 0.0f;
34   for (int i = 0; i < 5; ++i) {
35     sum += tensor1(i, i, i);
36   }
37   VERIFY_IS_EQUAL(result1(), sum);
38 
39   Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7);
40   tensor2.setRandom();
41   array<ptrdiff_t, 5> dims = { { 2, 1, 0, 3, 4 } };
42   Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims);
43   VERIFY_IS_EQUAL(result2.rank(), 0);
44   sum = 0.0f;
45   for (int i = 0; i < 7; ++i) {
46     sum += tensor2(i, i, i, i, i);
47   }
48   VERIFY_IS_EQUAL(result2(), sum);
49 }
50 
51 
52 template <int DataLayout>
test_simple_trace()53 static void test_simple_trace() {
54   Tensor<float, 3, DataLayout> tensor1(3, 5, 3);
55   tensor1.setRandom();
56   array<ptrdiff_t, 2> dims1 = { { 0, 2 } };
57   Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1);
58   VERIFY_IS_EQUAL(result1.rank(), 1);
59   VERIFY_IS_EQUAL(result1.dimension(0), 5);
60   float sum = 0.0f;
61   for (int i = 0; i < 5; ++i) {
62     sum = 0.0f;
63     for (int j = 0; j < 3; ++j) {
64       sum += tensor1(j, i, j);
65     }
66     VERIFY_IS_EQUAL(result1(i), sum);
67   }
68 
69   Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7);
70   tensor2.setRandom();
71   array<ptrdiff_t, 2> dims2 = { { 2, 3 } };
72   Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2);
73   VERIFY_IS_EQUAL(result2.rank(), 2);
74   VERIFY_IS_EQUAL(result2.dimension(0), 5);
75   VERIFY_IS_EQUAL(result2.dimension(1), 5);
76   for (int i = 0; i < 5; ++i) {
77     for (int j = 0; j < 5; ++j) {
78       sum = 0.0f;
79       for (int k = 0; k < 7; ++k) {
80         sum += tensor2(i, j, k, k);
81       }
82       VERIFY_IS_EQUAL(result2(i, j), sum);
83     }
84   }
85 
86   array<ptrdiff_t, 2> dims3 = { { 1, 0 } };
87   Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3);
88   VERIFY_IS_EQUAL(result3.rank(), 2);
89   VERIFY_IS_EQUAL(result3.dimension(0), 7);
90   VERIFY_IS_EQUAL(result3.dimension(1), 7);
91   for (int i = 0; i < 7; ++i) {
92     for (int j = 0; j < 7; ++j) {
93       sum = 0.0f;
94       for (int k = 0; k < 5; ++k) {
95         sum += tensor2(k, k, i, j);
96       }
97       VERIFY_IS_EQUAL(result3(i, j), sum);
98     }
99   }
100 
101   Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3);
102   tensor3.setRandom();
103   array<ptrdiff_t, 3> dims4 = { { 0, 2, 4 } };
104   Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4);
105   VERIFY_IS_EQUAL(result4.rank(), 2);
106   VERIFY_IS_EQUAL(result4.dimension(0), 7);
107   VERIFY_IS_EQUAL(result4.dimension(1), 7);
108   for (int i = 0; i < 7; ++i) {
109     for (int j = 0; j < 7; ++j) {
110       sum = 0.0f;
111       for (int k = 0; k < 3; ++k) {
112         sum += tensor3(k, i, k, j, k);
113       }
114       VERIFY_IS_EQUAL(result4(i, j), sum);
115     }
116   }
117 
118   Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5);
119   tensor4.setRandom();
120   array<ptrdiff_t, 2> dims5 = { { 1, 3 } };
121   Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5);
122   VERIFY_IS_EQUAL(result5.rank(), 3);
123   VERIFY_IS_EQUAL(result5.dimension(0), 3);
124   VERIFY_IS_EQUAL(result5.dimension(1), 4);
125   VERIFY_IS_EQUAL(result5.dimension(2), 5);
126   for (int i = 0; i < 3; ++i) {
127     for (int j = 0; j < 4; ++j) {
128       for (int k = 0; k < 5; ++k) {
129         sum = 0.0f;
130         for (int l = 0; l < 7; ++l) {
131           sum += tensor4(i, l, j, l, k);
132         }
133         VERIFY_IS_EQUAL(result5(i, j, k), sum);
134       }
135     }
136   }
137 }
138 
139 
140 template<int DataLayout>
test_trace_in_expr()141 static void test_trace_in_expr() {
142   Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3);
143   tensor.setRandom();
144   array<ptrdiff_t, 2> dims = { { 1, 3 } };
145   Tensor<float, 2, DataLayout> result(2, 5);
146   result = result.constant(1.0f) - tensor.trace(dims);
147   VERIFY_IS_EQUAL(result.rank(), 2);
148   VERIFY_IS_EQUAL(result.dimension(0), 2);
149   VERIFY_IS_EQUAL(result.dimension(1), 5);
150   float sum = 0.0f;
151   for (int i = 0; i < 2; ++i) {
152     for (int j = 0; j < 5; ++j) {
153       sum = 0.0f;
154       for (int k = 0; k < 3; ++k) {
155         sum += tensor(i, k, j, k);
156       }
157       VERIFY_IS_EQUAL(result(i, j), 1.0f - sum);
158     }
159   }
160 }
161 
162 
// Test entry point. Every trace scenario runs once with ColMajor storage and
// then once with RowMajor storage, in the order listed below.
EIGEN_DECLARE_TEST(cxx11_tensor_trace) {
  // Rank-0 input, empty dimension list.
  CALL_SUBTEST(test_0D_trace<ColMajor>());
  CALL_SUBTEST(test_0D_trace<RowMajor>());

  // Trace over every dimension (implicit and explicit forms).
  CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
  CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());

  // Partial traces over subsets of dimensions.
  CALL_SUBTEST(test_simple_trace<ColMajor>());
  CALL_SUBTEST(test_simple_trace<RowMajor>());

  // Trace used inside a compound tensor expression.
  CALL_SUBTEST(test_trace_in_expr<ColMajor>());
  CALL_SUBTEST(test_trace_in_expr<RowMajor>());
}
173