xref: /aosp_15_r20/external/eigen/unsupported/test/kronecker_product.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2011 Kolja Brix <[email protected]>
// Copyright (C) 2011 Andreas Platen <[email protected]>
// Copyright (C) 2012 Chen-Pang He <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 
13 #ifdef EIGEN_TEST_PART_1
14 
15 #include "sparse.h"
16 #include <Eigen/SparseExtra>
17 #include <Eigen/KroneckerProduct>
18 
// Verifies that `ab` has exactly the expected shape.
//
// ab   : any object exposing rows() and cols().
// rows : expected number of rows.
// cols : expected number of columns.
template<typename MatrixType>
void check_dimension(const MatrixType& ab, const int rows,  const int cols)
{
  VERIFY_IS_EQUAL(ab.rows(), rows);
  VERIFY_IS_EQUAL(ab.cols(), cols);
}
25 
26 
// Verifies that `ab` is the reference 6x6 Kronecker product used by this test:
// the product of the 2x3 matrix and the 3x2 matrix that the test body fills
// in coefficient by coefficient.
//
// The check covers the shape, the number of non-zeros (all 36 entries are
// expected to be reported) and every single coefficient, one assertion per
// entry so that a failure names the offending (row, col) pair.
//
// ab : any object exposing rows(), cols(), nonZeros() and coeff(i,j).
template<typename MatrixType>
void check_kronecker_product(const MatrixType& ab)
{
  VERIFY_IS_EQUAL(ab.rows(), 6);
  VERIFY_IS_EQUAL(ab.cols(), 6);
  VERIFY_IS_EQUAL(ab.nonZeros(),  36);
  // row 0
  VERIFY_IS_APPROX(ab.coeff(0,0), -0.4017367630386106);
  VERIFY_IS_APPROX(ab.coeff(0,1),  0.1056863433932735);
  VERIFY_IS_APPROX(ab.coeff(0,2), -0.7255206194554212);
  VERIFY_IS_APPROX(ab.coeff(0,3),  0.1908653336744706);
  VERIFY_IS_APPROX(ab.coeff(0,4),  0.350864567234111);
  VERIFY_IS_APPROX(ab.coeff(0,5), -0.0923032108308013);
  // row 1
  VERIFY_IS_APPROX(ab.coeff(1,0),  0.415417514804677);
  VERIFY_IS_APPROX(ab.coeff(1,1), -0.2369227701722048);
  VERIFY_IS_APPROX(ab.coeff(1,2),  0.7502275131458511);
  VERIFY_IS_APPROX(ab.coeff(1,3), -0.4278731019742696);
  VERIFY_IS_APPROX(ab.coeff(1,4), -0.3628129162264507);
  VERIFY_IS_APPROX(ab.coeff(1,5),  0.2069210808481275);
  // row 2
  VERIFY_IS_APPROX(ab.coeff(2,0),  0.05465890160863986);
  VERIFY_IS_APPROX(ab.coeff(2,1), -0.2634092511419858);
  VERIFY_IS_APPROX(ab.coeff(2,2),  0.09871180285793758);
  VERIFY_IS_APPROX(ab.coeff(2,3), -0.4757066334017702);
  VERIFY_IS_APPROX(ab.coeff(2,4), -0.04773740823058334);
  VERIFY_IS_APPROX(ab.coeff(2,5),  0.2300535609645254);
  // row 3
  VERIFY_IS_APPROX(ab.coeff(3,0), -0.8172945853260133);
  VERIFY_IS_APPROX(ab.coeff(3,1),  0.2150086428359221);
  VERIFY_IS_APPROX(ab.coeff(3,2),  0.5825113847292743);
  VERIFY_IS_APPROX(ab.coeff(3,3), -0.1532433770097174);
  VERIFY_IS_APPROX(ab.coeff(3,4), -0.329383387282399);
  VERIFY_IS_APPROX(ab.coeff(3,5),  0.08665207912033064);
  // row 4
  VERIFY_IS_APPROX(ab.coeff(4,0),  0.8451267514863225);
  VERIFY_IS_APPROX(ab.coeff(4,1), -0.481996458918977);
  VERIFY_IS_APPROX(ab.coeff(4,2), -0.6023482390791535);
  VERIFY_IS_APPROX(ab.coeff(4,3),  0.3435339347164565);
  VERIFY_IS_APPROX(ab.coeff(4,4),  0.3406002157428891);
  VERIFY_IS_APPROX(ab.coeff(4,5), -0.1942526344200915);
  // row 5
  VERIFY_IS_APPROX(ab.coeff(5,0),  0.1111982482925399);
  VERIFY_IS_APPROX(ab.coeff(5,1), -0.5358806424754169);
  VERIFY_IS_APPROX(ab.coeff(5,2), -0.07925446559335647);
  VERIFY_IS_APPROX(ab.coeff(5,3),  0.3819388757769038);
  VERIFY_IS_APPROX(ab.coeff(5,4),  0.04481475387219876);
  VERIFY_IS_APPROX(ab.coeff(5,5), -0.2159688616158057);
}
70 
71 
// Verifies that `ab` is the reference product of the sparsely populated
// 4x5 and 3x2 operands built in the test body (3 and 2 stored entries).
//
// The result must be 12x10, store exactly 3*2 entries, and each entry
// a(i,j)*b(k,l) must sit at (i*3 + k, j*2 + l).
//
// ab : any object exposing rows(), cols(), nonZeros() and coeff(i,j).
template<typename MatrixType>
void check_sparse_kronecker_product(const MatrixType& ab)
{
  VERIFY_IS_EQUAL(ab.rows(), 12);
  VERIFY_IS_EQUAL(ab.cols(), 10);
  VERIFY_IS_EQUAL(ab.nonZeros(), 3*2);
  VERIFY_IS_APPROX(ab.coeff(3,0), -0.04);  // a(1,0) * b(0,0)
  VERIFY_IS_APPROX(ab.coeff(5,1),  0.05);  // a(1,0) * b(2,1)
  VERIFY_IS_APPROX(ab.coeff(0,6), -0.08);  // a(0,3) * b(0,0)
  VERIFY_IS_APPROX(ab.coeff(2,7),  0.10);  // a(0,3) * b(2,1)
  VERIFY_IS_APPROX(ab.coeff(6,8),  0.12);  // a(2,4) * b(0,0)
  VERIFY_IS_APPROX(ab.coeff(8,9), -0.15);  // a(2,4) * b(2,1)
}
85 
86 
EIGEN_DECLARE_TEST(kronecker_product)87 EIGEN_DECLARE_TEST(kronecker_product)
88 {
89   // DM = dense matrix; SM = sparse matrix
90 
91   Matrix<double, 2, 3> DM_a;
92   SparseMatrix<double> SM_a(2,3);
93   SM_a.insert(0,0) = DM_a.coeffRef(0,0) = -0.4461540300782201;
94   SM_a.insert(0,1) = DM_a.coeffRef(0,1) = -0.8057364375283049;
95   SM_a.insert(0,2) = DM_a.coeffRef(0,2) =  0.3896572459516341;
96   SM_a.insert(1,0) = DM_a.coeffRef(1,0) = -0.9076572187376921;
97   SM_a.insert(1,1) = DM_a.coeffRef(1,1) =  0.6469156566545853;
98   SM_a.insert(1,2) = DM_a.coeffRef(1,2) = -0.3658010398782789;
99 
100   MatrixXd             DM_b(3,2);
101   SparseMatrix<double> SM_b(3,2);
102   SM_b.insert(0,0) = DM_b.coeffRef(0,0) =  0.9004440976767099;
103   SM_b.insert(0,1) = DM_b.coeffRef(0,1) = -0.2368830858139832;
104   SM_b.insert(1,0) = DM_b.coeffRef(1,0) = -0.9311078389941825;
105   SM_b.insert(1,1) = DM_b.coeffRef(1,1) =  0.5310335762980047;
106   SM_b.insert(2,0) = DM_b.coeffRef(2,0) = -0.1225112806872035;
107   SM_b.insert(2,1) = DM_b.coeffRef(2,1) =  0.5903998022741264;
108 
109   SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
110 
111   // test DM_fixedSize = kroneckerProduct(DM_block,DM)
112   Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
113 
114   CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
115   CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b)));
116 
117   for(int i=0;i<DM_fix_ab.rows();++i)
118     for(int j=0;j<DM_fix_ab.cols();++j)
119        VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
120 
121   // test DM_block = kroneckerProduct(DM,DM)
122   MatrixXd DM_block_ab(10,15);
123   DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
124   CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
125 
126   // test DM = kroneckerProduct(DM,DM)
127   MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
128   CALL_SUBTEST(check_kronecker_product(DM_ab));
129   CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,DM_b)));
130 
131   // test SM = kroneckerProduct(SM,DM)
132   SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
133   CALL_SUBTEST(check_kronecker_product(SM_ab));
134   SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
135   CALL_SUBTEST(check_kronecker_product(SM_ab2));
136   CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,DM_b)));
137 
138   // test SM = kroneckerProduct(DM,SM)
139   SM_ab.setZero();
140   SM_ab.insert(0,0)=37.0;
141   SM_ab = kroneckerProduct(DM_a,SM_b);
142   CALL_SUBTEST(check_kronecker_product(SM_ab));
143   SM_ab2.setZero();
144   SM_ab2.insert(0,0)=37.0;
145   SM_ab2 = kroneckerProduct(DM_a,SM_b);
146   CALL_SUBTEST(check_kronecker_product(SM_ab2));
147   CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,SM_b)));
148 
149   // test SM = kroneckerProduct(SM,SM)
150   SM_ab.resize(2,33);
151   SM_ab.insert(0,0)=37.0;
152   SM_ab = kroneckerProduct(SM_a,SM_b);
153   CALL_SUBTEST(check_kronecker_product(SM_ab));
154   SM_ab2.resize(5,11);
155   SM_ab2.insert(0,0)=37.0;
156   SM_ab2 = kroneckerProduct(SM_a,SM_b);
157   CALL_SUBTEST(check_kronecker_product(SM_ab2));
158   CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,SM_b)));
159 
160   // test SM = kroneckerProduct(SM,SM) with sparse pattern
161   SM_a.resize(4,5);
162   SM_b.resize(3,2);
163   SM_a.resizeNonZeros(0);
164   SM_b.resizeNonZeros(0);
165   SM_a.insert(1,0) = -0.1;
166   SM_a.insert(0,3) = -0.2;
167   SM_a.insert(2,4) =  0.3;
168   SM_a.finalize();
169 
170   SM_b.insert(0,0) =  0.4;
171   SM_b.insert(2,1) = -0.5;
172   SM_b.finalize();
173   SM_ab.resize(1,1);
174   SM_ab.insert(0,0)=37.0;
175   SM_ab = kroneckerProduct(SM_a,SM_b);
176   CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
177 
178   // test dimension of result of DM = kroneckerProduct(DM,DM)
179   MatrixXd DM_a2(2,1);
180   MatrixXd DM_b2(5,4);
181   MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
182   CALL_SUBTEST(check_dimension(DM_ab2,2*5,1*4));
183   DM_a2.resize(10,9);
184   DM_b2.resize(4,8);
185   DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
186   CALL_SUBTEST(check_dimension(DM_ab2,10*4,9*8));
187 
188   for(int i = 0; i < g_repeat; i++)
189   {
190     double density = Eigen::internal::random<double>(0.01,0.5);
191     int ra = Eigen::internal::random<int>(1,50);
192     int ca = Eigen::internal::random<int>(1,50);
193     int rb = Eigen::internal::random<int>(1,50);
194     int cb = Eigen::internal::random<int>(1,50);
195     SparseMatrix<float,ColMajor> sA(ra,ca), sB(rb,cb), sC;
196     SparseMatrix<float,RowMajor> sC2;
197     MatrixXf dA(ra,ca), dB(rb,cb), dC;
198     initSparse(density, dA, sA);
199     initSparse(density, dB, sB);
200 
201     sC = kroneckerProduct(sA,sB);
202     dC = kroneckerProduct(dA,dB);
203     VERIFY_IS_APPROX(MatrixXf(sC),dC);
204 
205     sC = kroneckerProduct(sA.transpose(),sB);
206     dC = kroneckerProduct(dA.transpose(),dB);
207     VERIFY_IS_APPROX(MatrixXf(sC),dC);
208 
209     sC = kroneckerProduct(sA.transpose(),sB.transpose());
210     dC = kroneckerProduct(dA.transpose(),dB.transpose());
211     VERIFY_IS_APPROX(MatrixXf(sC),dC);
212 
213     sC = kroneckerProduct(sA,sB.transpose());
214     dC = kroneckerProduct(dA,dB.transpose());
215     VERIFY_IS_APPROX(MatrixXf(sC),dC);
216 
217     sC2 = kroneckerProduct(sA,sB);
218     dC = kroneckerProduct(dA,dB);
219     VERIFY_IS_APPROX(MatrixXf(sC2),dC);
220 
221     sC2 = kroneckerProduct(dA,sB);
222     dC = kroneckerProduct(dA,dB);
223     VERIFY_IS_APPROX(MatrixXf(sC2),dC);
224 
225     sC2 = kroneckerProduct(sA,dB);
226     dC = kroneckerProduct(dA,dB);
227     VERIFY_IS_APPROX(MatrixXf(sC2),dC);
228 
229     sC2 = kroneckerProduct(2*sA,sB);
230     dC = kroneckerProduct(2*dA,dB);
231     VERIFY_IS_APPROX(MatrixXf(sC2),dC);
232   }
233 }
234 
235 #endif
236 
237 #ifdef EIGEN_TEST_PART_2
238 
// Simply check that a dense Kronecker product does not need the sparse module.
240 #include "main.h"
241 #include <Eigen/KroneckerProduct>
242 
// Part 2: dense-only build. Computes the Kronecker product of two random
// dense matrices and checks its shape and that every (i,j) block of the
// result equals a(i,j)*b.
EIGEN_DECLARE_TEST(kronecker_product)
{
  MatrixXd a(2,2), b(3,3), c;
  a.setRandom();
  b.setRandom();
  c = kroneckerProduct(a,b);

  VERIFY_IS_EQUAL(c.rows(), a.rows()*b.rows());
  VERIFY_IS_EQUAL(c.cols(), a.cols()*b.cols());

  // block (i,j) of the result starts at (i*b.rows(), j*b.cols())
  for(Eigen::Index i=0;i<a.rows();++i)
    for(Eigen::Index j=0;j<a.cols();++j)
      VERIFY_IS_APPROX(c.block(i*b.rows(),j*b.cols(),b.rows(),b.cols()), a(i,j)*b);
}
251 
252 #endif
253