1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package org.apache.commons.math3.distribution;
18 
19 import org.apache.commons.math3.exception.DimensionMismatchException;
20 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
21 import org.apache.commons.math3.linear.EigenDecomposition;
22 import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
23 import org.apache.commons.math3.linear.RealMatrix;
24 import org.apache.commons.math3.linear.SingularMatrixException;
25 import org.apache.commons.math3.random.RandomGenerator;
26 import org.apache.commons.math3.random.Well19937c;
27 import org.apache.commons.math3.util.FastMath;
28 import org.apache.commons.math3.util.MathArrays;
29 
30 /**
31  * Implementation of the multivariate normal (Gaussian) distribution.
32  *
33  * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">Multivariate normal
34  *     distribution (Wikipedia)</a>
35  * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">Multivariate
36  *     normal distribution (MathWorld)</a>
37  * @since 3.1
38  */
39 public class MultivariateNormalDistribution extends AbstractMultivariateRealDistribution {
40     /** Vector of means. */
41     private final double[] means;
42 
43     /** Covariance matrix. */
44     private final RealMatrix covarianceMatrix;
45 
46     /** The matrix inverse of the covariance matrix. */
47     private final RealMatrix covarianceMatrixInverse;
48 
49     /** The determinant of the covariance matrix. */
50     private final double covarianceMatrixDeterminant;
51 
52     /** Matrix used in computation of samples. */
53     private final RealMatrix samplingMatrix;
54 
55     /**
56      * Creates a multivariate normal distribution with the given mean vector and covariance matrix.
57      * <br>
58      * The number of dimensions is equal to the length of the mean vector and to the number of rows
59      * and columns of the covariance matrix. It is frequently written as "p" in formulae.
60      *
61      * <p><b>Note:</b> this constructor will implicitly create an instance of {@link Well19937c} as
62      * random generator to be used for sampling only (see {@link #sample()} and {@link
63      * #sample(int)}). In case no sampling is needed for the created distribution, it is advised to
64      * pass {@code null} as random generator via the appropriate constructors to avoid the
65      * additional initialisation overhead.
66      *
67      * @param means Vector of means.
68      * @param covariances Covariance matrix.
69      * @throws DimensionMismatchException if the arrays length are inconsistent.
70      * @throws SingularMatrixException if the eigenvalue decomposition cannot be performed on the
71      *     provided covariance matrix.
72      * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is negative.
73      */
MultivariateNormalDistribution(final double[] means, final double[][] covariances)74     public MultivariateNormalDistribution(final double[] means, final double[][] covariances)
75             throws SingularMatrixException,
76                     DimensionMismatchException,
77                     NonPositiveDefiniteMatrixException {
78         this(new Well19937c(), means, covariances);
79     }
80 
81     /**
82      * Creates a multivariate normal distribution with the given mean vector and covariance matrix.
83      * <br>
84      * The number of dimensions is equal to the length of the mean vector and to the number of rows
85      * and columns of the covariance matrix. It is frequently written as "p" in formulae.
86      *
87      * @param rng Random Number Generator.
88      * @param means Vector of means.
89      * @param covariances Covariance matrix.
90      * @throws DimensionMismatchException if the arrays length are inconsistent.
91      * @throws SingularMatrixException if the eigenvalue decomposition cannot be performed on the
92      *     provided covariance matrix.
93      * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is negative.
94      */
MultivariateNormalDistribution( RandomGenerator rng, final double[] means, final double[][] covariances)95     public MultivariateNormalDistribution(
96             RandomGenerator rng, final double[] means, final double[][] covariances)
97             throws SingularMatrixException,
98                     DimensionMismatchException,
99                     NonPositiveDefiniteMatrixException {
100         super(rng, means.length);
101 
102         final int dim = means.length;
103 
104         if (covariances.length != dim) {
105             throw new DimensionMismatchException(covariances.length, dim);
106         }
107 
108         for (int i = 0; i < dim; i++) {
109             if (dim != covariances[i].length) {
110                 throw new DimensionMismatchException(covariances[i].length, dim);
111             }
112         }
113 
114         this.means = MathArrays.copyOf(means);
115 
116         covarianceMatrix = new Array2DRowRealMatrix(covariances);
117 
118         // Covariance matrix eigen decomposition.
119         final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
120 
121         // Compute and store the inverse.
122         covarianceMatrixInverse = covMatDec.getSolver().getInverse();
123         // Compute and store the determinant.
124         covarianceMatrixDeterminant = covMatDec.getDeterminant();
125 
126         // Eigenvalues of the covariance matrix.
127         final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
128 
129         for (int i = 0; i < covMatEigenvalues.length; i++) {
130             if (covMatEigenvalues[i] < 0) {
131                 throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
132             }
133         }
134 
135         // Matrix where each column is an eigenvector of the covariance matrix.
136         final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
137         for (int v = 0; v < dim; v++) {
138             final double[] evec = covMatDec.getEigenvector(v).toArray();
139             covMatEigenvectors.setColumn(v, evec);
140         }
141 
142         final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
143 
144         // Scale each eigenvector by the square root of its eigenvalue.
145         for (int row = 0; row < dim; row++) {
146             final double factor = FastMath.sqrt(covMatEigenvalues[row]);
147             for (int col = 0; col < dim; col++) {
148                 tmpMatrix.multiplyEntry(row, col, factor);
149             }
150         }
151 
152         samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
153     }
154 
155     /**
156      * Gets the mean vector.
157      *
158      * @return the mean vector.
159      */
getMeans()160     public double[] getMeans() {
161         return MathArrays.copyOf(means);
162     }
163 
164     /**
165      * Gets the covariance matrix.
166      *
167      * @return the covariance matrix.
168      */
getCovariances()169     public RealMatrix getCovariances() {
170         return covarianceMatrix.copy();
171     }
172 
173     /** {@inheritDoc} */
density(final double[] vals)174     public double density(final double[] vals) throws DimensionMismatchException {
175         final int dim = getDimension();
176         if (vals.length != dim) {
177             throw new DimensionMismatchException(vals.length, dim);
178         }
179 
180         return FastMath.pow(2 * FastMath.PI, -0.5 * dim)
181                 * FastMath.pow(covarianceMatrixDeterminant, -0.5)
182                 * getExponentTerm(vals);
183     }
184 
185     /**
186      * Gets the square root of each element on the diagonal of the covariance matrix.
187      *
188      * @return the standard deviations.
189      */
getStandardDeviations()190     public double[] getStandardDeviations() {
191         final int dim = getDimension();
192         final double[] std = new double[dim];
193         final double[][] s = covarianceMatrix.getData();
194         for (int i = 0; i < dim; i++) {
195             std[i] = FastMath.sqrt(s[i][i]);
196         }
197         return std;
198     }
199 
200     /** {@inheritDoc} */
201     @Override
sample()202     public double[] sample() {
203         final int dim = getDimension();
204         final double[] normalVals = new double[dim];
205 
206         for (int i = 0; i < dim; i++) {
207             normalVals[i] = random.nextGaussian();
208         }
209 
210         final double[] vals = samplingMatrix.operate(normalVals);
211 
212         for (int i = 0; i < dim; i++) {
213             vals[i] += means[i];
214         }
215 
216         return vals;
217     }
218 
219     /**
220      * Computes the term used in the exponent (see definition of the distribution).
221      *
222      * @param values Values at which to compute density.
223      * @return the multiplication factor of density calculations.
224      */
getExponentTerm(final double[] values)225     private double getExponentTerm(final double[] values) {
226         final double[] centered = new double[values.length];
227         for (int i = 0; i < centered.length; i++) {
228             centered[i] = values[i] - getMeans()[i];
229         }
230         final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
231         double sum = 0;
232         for (int i = 0; i < preMultiplied.length; i++) {
233             sum += preMultiplied[i] * centered[i];
234         }
235         return FastMath.exp(-0.5 * sum);
236     }
237 }
238