xref: /aosp_15_r20/external/webrtc/rtc_tools/frame_analyzer/linear_least_squares.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "rtc_tools/frame_analyzer/linear_least_squares.h"
12 
13 #include <math.h>
14 
15 #include <cstdint>
16 #include <cstdlib>
17 #include <functional>
18 #include <numeric>
19 #include <type_traits>
20 #include <utility>
21 
22 #include "rtc_base/checks.h"
23 #include "rtc_base/logging.h"
24 
25 namespace webrtc {
26 namespace test {
27 
// Dense matrix built from nested valarrays. By convention in this file the
// outer valarray holds the columns and each inner valarray holds the entries
// of one column.
template <typename T>
using Matrix = std::valarray<std::valarray<T>>;
30 
31 namespace {
32 
// Returns the dot product of `a` and `b`, accumulated in type `R`.
// Each element is converted to `R` before multiplying, so the products are
// formed in `R` rather than in the integer-promoted type of `T`. Without this,
// two uint16_t operands would be promoted to int, and e.g. 65535 * 65535
// overflows a 32-bit int (undefined behavior).
// Crashes (RTC_CHECK) if `a` and `b` differ in size.
template <typename R, typename T>
R DotProduct(const std::valarray<T>& a, const std::valarray<T>& b) {
  RTC_CHECK_EQ(a.size(), b.size());
  return std::inner_product(
      std::begin(a), std::end(a), std::begin(b), R(0), std::plus<R>(),
      [](const T& lhs, const T& rhs) {
        return static_cast<R>(lhs) * static_cast<R>(rhs);
      });
}
38 
39 // Calculates a^T * b.
40 template <typename R, typename T>
MatrixMultiply(const Matrix<T> & a,const Matrix<T> & b)41 Matrix<R> MatrixMultiply(const Matrix<T>& a, const Matrix<T>& b) {
42   Matrix<R> result(std::valarray<R>(a.size()), b.size());
43   for (size_t i = 0; i < a.size(); ++i) {
44     for (size_t j = 0; j < b.size(); ++j)
45       result[j][i] = DotProduct<R>(a[i], b[j]);
46   }
47 
48   return result;
49 }
50 
51 template <typename T>
Transpose(const Matrix<T> & matrix)52 Matrix<T> Transpose(const Matrix<T>& matrix) {
53   if (matrix.size() == 0)
54     return Matrix<T>();
55   const size_t rows = matrix.size();
56   const size_t columns = matrix[0].size();
57   Matrix<T> result(std::valarray<T>(rows), columns);
58 
59   for (size_t i = 0; i < rows; ++i) {
60     for (size_t j = 0; j < columns; ++j)
61       result[j][i] = matrix[i][j];
62   }
63 
64   return result;
65 }
66 
// Returns a new valarray of the same length as `v` whose elements are the
// elements of `v` static_cast from T to R.
template <typename R, typename T>
std::valarray<R> ConvertTo(const std::valarray<T>& v) {
  const size_t count = v.size();
  std::valarray<R> converted(count);
  for (size_t index = 0; index != count; ++index)
    converted[index] = static_cast<R>(v[index]);
  return converted;
}
75 
76 // Convert valarray Matrix from type T to type R.
77 template <typename R, typename T>
ConvertTo(const Matrix<T> & mat)78 Matrix<R> ConvertTo(const Matrix<T>& mat) {
79   Matrix<R> result(mat.size());
80   for (size_t i = 0; i < mat.size(); ++i)
81     result[i] = ConvertTo<R>(mat[i]);
82   return result;
83 }
84 
85 // Convert from valarray Matrix back to the more conventional std::vector.
86 template <typename T>
ToVectorMatrix(const Matrix<T> & m)87 std::vector<std::vector<T>> ToVectorMatrix(const Matrix<T>& m) {
88   std::vector<std::vector<T>> result;
89   for (const std::valarray<T>& v : m)
90     result.emplace_back(std::begin(v), std::end(v));
91   return result;
92 }
93 
94 // Create a valarray Matrix from a conventional std::vector.
95 template <typename T>
FromVectorMatrix(const std::vector<std::vector<T>> & mat)96 Matrix<T> FromVectorMatrix(const std::vector<std::vector<T>>& mat) {
97   Matrix<T> result(mat.size());
98   for (size_t i = 0; i < mat.size(); ++i)
99     result[i] = std::valarray<T>(mat[i].data(), mat[i].size());
100   return result;
101 }
102 
103 // Returns `matrix_to_invert`^-1 * `right_hand_matrix`. `matrix_to_invert` must
104 // have square size.
GaussianElimination(Matrix<double> matrix_to_invert,Matrix<double> right_hand_matrix)105 Matrix<double> GaussianElimination(Matrix<double> matrix_to_invert,
106                                    Matrix<double> right_hand_matrix) {
107   // `n` is the width/height of `matrix_to_invert`.
108   const size_t n = matrix_to_invert.size();
109   // Make sure `matrix_to_invert` has square size.
110   for (const std::valarray<double>& column : matrix_to_invert)
111     RTC_CHECK_EQ(n, column.size());
112   // Make sure `right_hand_matrix` has correct size.
113   for (const std::valarray<double>& column : right_hand_matrix)
114     RTC_CHECK_EQ(n, column.size());
115 
116   // Transpose the matrices before and after so that we can perform Gaussian
117   // elimination on the columns instead of the rows, since that is easier with
118   // our representation.
119   matrix_to_invert = Transpose(matrix_to_invert);
120   right_hand_matrix = Transpose(right_hand_matrix);
121 
122   // Loop over the diagonal of `matrix_to_invert` and perform column reduction.
123   // Column reduction is a sequence of elementary column operations that is
124   // performed on both `matrix_to_invert` and `right_hand_matrix` until
125   // `matrix_to_invert` has been transformed to the identity matrix.
126   for (size_t diagonal_index = 0; diagonal_index < n; ++diagonal_index) {
127     // Make sure the diagonal element has the highest absolute value by
128     // swapping columns if necessary.
129     for (size_t column = diagonal_index + 1; column < n; ++column) {
130       if (std::abs(matrix_to_invert[column][diagonal_index]) >
131           std::abs(matrix_to_invert[diagonal_index][diagonal_index])) {
132         std::swap(matrix_to_invert[column], matrix_to_invert[diagonal_index]);
133         std::swap(right_hand_matrix[column], right_hand_matrix[diagonal_index]);
134       }
135     }
136 
137     // Reduce the diagonal element to be 1, by dividing the column with that
138     // value. If the diagonal element is 0, it means the system of equations has
139     // many solutions, and in that case we will return an arbitrary solution.
140     if (matrix_to_invert[diagonal_index][diagonal_index] == 0.0) {
141       RTC_LOG(LS_WARNING) << "Matrix is not invertible, ignoring.";
142       continue;
143     }
144     const double diagonal_element =
145         matrix_to_invert[diagonal_index][diagonal_index];
146     matrix_to_invert[diagonal_index] /= diagonal_element;
147     right_hand_matrix[diagonal_index] /= diagonal_element;
148 
149     // Eliminate the other entries in row `diagonal_index` by making them zero.
150     for (size_t column = 0; column < n; ++column) {
151       if (column == diagonal_index)
152         continue;
153       const double row_element = matrix_to_invert[column][diagonal_index];
154       matrix_to_invert[column] -=
155           row_element * matrix_to_invert[diagonal_index];
156       right_hand_matrix[column] -=
157           row_element * right_hand_matrix[diagonal_index];
158     }
159   }
160 
161   // Transpose the result before returning it, explained in comment above.
162   return Transpose(right_hand_matrix);
163 }
164 
165 }  // namespace
166 
167 IncrementalLinearLeastSquares::IncrementalLinearLeastSquares() = default;
168 IncrementalLinearLeastSquares::~IncrementalLinearLeastSquares() = default;
169 
AddObservations(const std::vector<std::vector<uint8_t>> & x,const std::vector<std::vector<uint8_t>> & y)170 void IncrementalLinearLeastSquares::AddObservations(
171     const std::vector<std::vector<uint8_t>>& x,
172     const std::vector<std::vector<uint8_t>>& y) {
173   if (x.empty() || y.empty())
174     return;
175   // Make sure all columns are the same size.
176   const size_t n = x[0].size();
177   for (const std::vector<uint8_t>& column : x)
178     RTC_CHECK_EQ(n, column.size());
179   for (const std::vector<uint8_t>& column : y)
180     RTC_CHECK_EQ(n, column.size());
181 
182   // We will multiply the uint8_t values together, so we need to expand to a
183   // type that can safely store those values, i.e. uint16_t.
184   const Matrix<uint16_t> unpacked_x = ConvertTo<uint16_t>(FromVectorMatrix(x));
185   const Matrix<uint16_t> unpacked_y = ConvertTo<uint16_t>(FromVectorMatrix(y));
186 
187   const Matrix<uint64_t> xx = MatrixMultiply<uint64_t>(unpacked_x, unpacked_x);
188   const Matrix<uint64_t> xy = MatrixMultiply<uint64_t>(unpacked_x, unpacked_y);
189   if (sum_xx && sum_xy) {
190     *sum_xx += xx;
191     *sum_xy += xy;
192   } else {
193     sum_xx = xx;
194     sum_xy = xy;
195   }
196 }
197 
198 std::vector<std::vector<double>>
GetBestSolution() const199 IncrementalLinearLeastSquares::GetBestSolution() const {
200   RTC_CHECK(sum_xx && sum_xy) << "No observations have been added";
201   return ToVectorMatrix(GaussianElimination(ConvertTo<double>(*sum_xx),
202                                             ConvertTo<double>(*sum_xy)));
203 }
204 
205 }  // namespace test
206 }  // namespace webrtc
207