xref: /aosp_15_r20/external/libmonet/quantize/QuantizerWsmeans.java (revision 970e10460f970939fd510dd6ad3e0d65908272e3)
1*970e1046SAndroid Build Coastguard Worker /*
2*970e1046SAndroid Build Coastguard Worker  * Copyright 2021 Google LLC
3*970e1046SAndroid Build Coastguard Worker  *
4*970e1046SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*970e1046SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*970e1046SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*970e1046SAndroid Build Coastguard Worker  *
8*970e1046SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*970e1046SAndroid Build Coastguard Worker  *
10*970e1046SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*970e1046SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*970e1046SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*970e1046SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*970e1046SAndroid Build Coastguard Worker  * limitations under the License.
15*970e1046SAndroid Build Coastguard Worker  */
16*970e1046SAndroid Build Coastguard Worker 
17*970e1046SAndroid Build Coastguard Worker package com.google.ux.material.libmonet.quantize;
18*970e1046SAndroid Build Coastguard Worker 
19*970e1046SAndroid Build Coastguard Worker import static java.lang.Math.min;
20*970e1046SAndroid Build Coastguard Worker 
21*970e1046SAndroid Build Coastguard Worker import java.util.Arrays;
22*970e1046SAndroid Build Coastguard Worker import java.util.LinkedHashMap;
23*970e1046SAndroid Build Coastguard Worker import java.util.Map;
24*970e1046SAndroid Build Coastguard Worker import java.util.Random;
25*970e1046SAndroid Build Coastguard Worker 
26*970e1046SAndroid Build Coastguard Worker /**
27*970e1046SAndroid Build Coastguard Worker  * An image quantizer that improves on the speed of a standard K-Means algorithm by implementing
28*970e1046SAndroid Build Coastguard Worker  * several optimizations, including deduping identical pixels and a triangle inequality rule that
29*970e1046SAndroid Build Coastguard Worker  * reduces the number of comparisons needed to identify which cluster a point should be moved to.
30*970e1046SAndroid Build Coastguard Worker  *
31*970e1046SAndroid Build Coastguard Worker  * <p>Wsmeans stands for Weighted Square Means.
32*970e1046SAndroid Build Coastguard Worker  *
33*970e1046SAndroid Build Coastguard Worker  * <p>This algorithm was designed by M. Emre Celebi, and was found in their 2011 paper, Improving
34*970e1046SAndroid Build Coastguard Worker  * the Performance of K-Means for Color Quantization. https://arxiv.org/abs/1101.0395
35*970e1046SAndroid Build Coastguard Worker  */
36*970e1046SAndroid Build Coastguard Worker public final class QuantizerWsmeans {
QuantizerWsmeans()37*970e1046SAndroid Build Coastguard Worker   private QuantizerWsmeans() {}
38*970e1046SAndroid Build Coastguard Worker 
39*970e1046SAndroid Build Coastguard Worker   private static final class Distance implements Comparable<Distance> {
40*970e1046SAndroid Build Coastguard Worker     int index;
41*970e1046SAndroid Build Coastguard Worker     double distance;
42*970e1046SAndroid Build Coastguard Worker 
Distance()43*970e1046SAndroid Build Coastguard Worker     Distance() {
44*970e1046SAndroid Build Coastguard Worker       this.index = -1;
45*970e1046SAndroid Build Coastguard Worker       this.distance = -1;
46*970e1046SAndroid Build Coastguard Worker     }
47*970e1046SAndroid Build Coastguard Worker 
48*970e1046SAndroid Build Coastguard Worker     @Override
compareTo(Distance other)49*970e1046SAndroid Build Coastguard Worker     public int compareTo(Distance other) {
50*970e1046SAndroid Build Coastguard Worker       return ((Double) this.distance).compareTo(other.distance);
51*970e1046SAndroid Build Coastguard Worker     }
52*970e1046SAndroid Build Coastguard Worker   }
53*970e1046SAndroid Build Coastguard Worker 
54*970e1046SAndroid Build Coastguard Worker   private static final int MAX_ITERATIONS = 10;
55*970e1046SAndroid Build Coastguard Worker   private static final double MIN_MOVEMENT_DISTANCE = 3.0;
56*970e1046SAndroid Build Coastguard Worker 
57*970e1046SAndroid Build Coastguard Worker   /**
58*970e1046SAndroid Build Coastguard Worker    * Reduce the number of colors needed to represented the input, minimizing the difference between
59*970e1046SAndroid Build Coastguard Worker    * the original image and the recolored image.
60*970e1046SAndroid Build Coastguard Worker    *
61*970e1046SAndroid Build Coastguard Worker    * @param inputPixels Colors in ARGB format.
62*970e1046SAndroid Build Coastguard Worker    * @param startingClusters Defines the initial state of the quantizer. Passing an empty array is
63*970e1046SAndroid Build Coastguard Worker    *     fine, the implementation will create its own initial state that leads to reproducible
64*970e1046SAndroid Build Coastguard Worker    *     results for the same inputs. Passing an array that is the result of Wu quantization leads
65*970e1046SAndroid Build Coastguard Worker    *     to higher quality results.
66*970e1046SAndroid Build Coastguard Worker    * @param maxColors The number of colors to divide the image into. A lower number of colors may be
67*970e1046SAndroid Build Coastguard Worker    *     returned.
68*970e1046SAndroid Build Coastguard Worker    * @return Map with keys of colors in ARGB format, values of how many of the input pixels belong
69*970e1046SAndroid Build Coastguard Worker    *     to the color.
70*970e1046SAndroid Build Coastguard Worker    */
quantize( int[] inputPixels, int[] startingClusters, int maxColors)71*970e1046SAndroid Build Coastguard Worker   public static Map<Integer, Integer> quantize(
72*970e1046SAndroid Build Coastguard Worker       int[] inputPixels, int[] startingClusters, int maxColors) {
73*970e1046SAndroid Build Coastguard Worker     // Uses a seeded random number generator to ensure consistent results.
74*970e1046SAndroid Build Coastguard Worker     Random random = new Random(0x42688);
75*970e1046SAndroid Build Coastguard Worker 
76*970e1046SAndroid Build Coastguard Worker     Map<Integer, Integer> pixelToCount = new LinkedHashMap<>();
77*970e1046SAndroid Build Coastguard Worker     double[][] points = new double[inputPixels.length][];
78*970e1046SAndroid Build Coastguard Worker     int[] pixels = new int[inputPixels.length];
79*970e1046SAndroid Build Coastguard Worker     PointProvider pointProvider = new PointProviderLab();
80*970e1046SAndroid Build Coastguard Worker 
81*970e1046SAndroid Build Coastguard Worker     int pointCount = 0;
82*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < inputPixels.length; i++) {
83*970e1046SAndroid Build Coastguard Worker       int inputPixel = inputPixels[i];
84*970e1046SAndroid Build Coastguard Worker       Integer pixelCount = pixelToCount.get(inputPixel);
85*970e1046SAndroid Build Coastguard Worker       if (pixelCount == null) {
86*970e1046SAndroid Build Coastguard Worker         points[pointCount] = pointProvider.fromInt(inputPixel);
87*970e1046SAndroid Build Coastguard Worker         pixels[pointCount] = inputPixel;
88*970e1046SAndroid Build Coastguard Worker         pointCount++;
89*970e1046SAndroid Build Coastguard Worker 
90*970e1046SAndroid Build Coastguard Worker         pixelToCount.put(inputPixel, 1);
91*970e1046SAndroid Build Coastguard Worker       } else {
92*970e1046SAndroid Build Coastguard Worker         pixelToCount.put(inputPixel, pixelCount + 1);
93*970e1046SAndroid Build Coastguard Worker       }
94*970e1046SAndroid Build Coastguard Worker     }
95*970e1046SAndroid Build Coastguard Worker 
96*970e1046SAndroid Build Coastguard Worker     int[] counts = new int[pointCount];
97*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < pointCount; i++) {
98*970e1046SAndroid Build Coastguard Worker       int pixel = pixels[i];
99*970e1046SAndroid Build Coastguard Worker       int count = pixelToCount.get(pixel);
100*970e1046SAndroid Build Coastguard Worker       counts[i] = count;
101*970e1046SAndroid Build Coastguard Worker     }
102*970e1046SAndroid Build Coastguard Worker 
103*970e1046SAndroid Build Coastguard Worker     int clusterCount = min(maxColors, pointCount);
104*970e1046SAndroid Build Coastguard Worker     if (startingClusters.length != 0) {
105*970e1046SAndroid Build Coastguard Worker       clusterCount = min(clusterCount, startingClusters.length);
106*970e1046SAndroid Build Coastguard Worker     }
107*970e1046SAndroid Build Coastguard Worker 
108*970e1046SAndroid Build Coastguard Worker     double[][] clusters = new double[clusterCount][];
109*970e1046SAndroid Build Coastguard Worker     int clustersCreated = 0;
110*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < startingClusters.length; i++) {
111*970e1046SAndroid Build Coastguard Worker       clusters[i] = pointProvider.fromInt(startingClusters[i]);
112*970e1046SAndroid Build Coastguard Worker       clustersCreated++;
113*970e1046SAndroid Build Coastguard Worker     }
114*970e1046SAndroid Build Coastguard Worker 
115*970e1046SAndroid Build Coastguard Worker     int additionalClustersNeeded = clusterCount - clustersCreated;
116*970e1046SAndroid Build Coastguard Worker     if (additionalClustersNeeded > 0) {
117*970e1046SAndroid Build Coastguard Worker       for (int i = 0; i < additionalClustersNeeded; i++) {}
118*970e1046SAndroid Build Coastguard Worker     }
119*970e1046SAndroid Build Coastguard Worker 
120*970e1046SAndroid Build Coastguard Worker     int[] clusterIndices = new int[pointCount];
121*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < pointCount; i++) {
122*970e1046SAndroid Build Coastguard Worker       clusterIndices[i] = random.nextInt(clusterCount);
123*970e1046SAndroid Build Coastguard Worker     }
124*970e1046SAndroid Build Coastguard Worker 
125*970e1046SAndroid Build Coastguard Worker     int[][] indexMatrix = new int[clusterCount][];
126*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < clusterCount; i++) {
127*970e1046SAndroid Build Coastguard Worker       indexMatrix[i] = new int[clusterCount];
128*970e1046SAndroid Build Coastguard Worker     }
129*970e1046SAndroid Build Coastguard Worker 
130*970e1046SAndroid Build Coastguard Worker     Distance[][] distanceToIndexMatrix = new Distance[clusterCount][];
131*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < clusterCount; i++) {
132*970e1046SAndroid Build Coastguard Worker       distanceToIndexMatrix[i] = new Distance[clusterCount];
133*970e1046SAndroid Build Coastguard Worker       for (int j = 0; j < clusterCount; j++) {
134*970e1046SAndroid Build Coastguard Worker         distanceToIndexMatrix[i][j] = new Distance();
135*970e1046SAndroid Build Coastguard Worker       }
136*970e1046SAndroid Build Coastguard Worker     }
137*970e1046SAndroid Build Coastguard Worker 
138*970e1046SAndroid Build Coastguard Worker     int[] pixelCountSums = new int[clusterCount];
139*970e1046SAndroid Build Coastguard Worker     for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
140*970e1046SAndroid Build Coastguard Worker       for (int i = 0; i < clusterCount; i++) {
141*970e1046SAndroid Build Coastguard Worker         for (int j = i + 1; j < clusterCount; j++) {
142*970e1046SAndroid Build Coastguard Worker           double distance = pointProvider.distance(clusters[i], clusters[j]);
143*970e1046SAndroid Build Coastguard Worker           distanceToIndexMatrix[j][i].distance = distance;
144*970e1046SAndroid Build Coastguard Worker           distanceToIndexMatrix[j][i].index = i;
145*970e1046SAndroid Build Coastguard Worker           distanceToIndexMatrix[i][j].distance = distance;
146*970e1046SAndroid Build Coastguard Worker           distanceToIndexMatrix[i][j].index = j;
147*970e1046SAndroid Build Coastguard Worker         }
148*970e1046SAndroid Build Coastguard Worker         Arrays.sort(distanceToIndexMatrix[i]);
149*970e1046SAndroid Build Coastguard Worker         for (int j = 0; j < clusterCount; j++) {
150*970e1046SAndroid Build Coastguard Worker           indexMatrix[i][j] = distanceToIndexMatrix[i][j].index;
151*970e1046SAndroid Build Coastguard Worker         }
152*970e1046SAndroid Build Coastguard Worker       }
153*970e1046SAndroid Build Coastguard Worker 
154*970e1046SAndroid Build Coastguard Worker       int pointsMoved = 0;
155*970e1046SAndroid Build Coastguard Worker       for (int i = 0; i < pointCount; i++) {
156*970e1046SAndroid Build Coastguard Worker         double[] point = points[i];
157*970e1046SAndroid Build Coastguard Worker         int previousClusterIndex = clusterIndices[i];
158*970e1046SAndroid Build Coastguard Worker         double[] previousCluster = clusters[previousClusterIndex];
159*970e1046SAndroid Build Coastguard Worker         double previousDistance = pointProvider.distance(point, previousCluster);
160*970e1046SAndroid Build Coastguard Worker 
161*970e1046SAndroid Build Coastguard Worker         double minimumDistance = previousDistance;
162*970e1046SAndroid Build Coastguard Worker         int newClusterIndex = -1;
163*970e1046SAndroid Build Coastguard Worker         for (int j = 0; j < clusterCount; j++) {
164*970e1046SAndroid Build Coastguard Worker           if (distanceToIndexMatrix[previousClusterIndex][j].distance >= 4 * previousDistance) {
165*970e1046SAndroid Build Coastguard Worker             continue;
166*970e1046SAndroid Build Coastguard Worker           }
167*970e1046SAndroid Build Coastguard Worker           double distance = pointProvider.distance(point, clusters[j]);
168*970e1046SAndroid Build Coastguard Worker           if (distance < minimumDistance) {
169*970e1046SAndroid Build Coastguard Worker             minimumDistance = distance;
170*970e1046SAndroid Build Coastguard Worker             newClusterIndex = j;
171*970e1046SAndroid Build Coastguard Worker           }
172*970e1046SAndroid Build Coastguard Worker         }
173*970e1046SAndroid Build Coastguard Worker         if (newClusterIndex != -1) {
174*970e1046SAndroid Build Coastguard Worker           double distanceChange =
175*970e1046SAndroid Build Coastguard Worker               Math.abs(Math.sqrt(minimumDistance) - Math.sqrt(previousDistance));
176*970e1046SAndroid Build Coastguard Worker           if (distanceChange > MIN_MOVEMENT_DISTANCE) {
177*970e1046SAndroid Build Coastguard Worker             pointsMoved++;
178*970e1046SAndroid Build Coastguard Worker             clusterIndices[i] = newClusterIndex;
179*970e1046SAndroid Build Coastguard Worker           }
180*970e1046SAndroid Build Coastguard Worker         }
181*970e1046SAndroid Build Coastguard Worker       }
182*970e1046SAndroid Build Coastguard Worker 
183*970e1046SAndroid Build Coastguard Worker       if (pointsMoved == 0 && iteration != 0) {
184*970e1046SAndroid Build Coastguard Worker         break;
185*970e1046SAndroid Build Coastguard Worker       }
186*970e1046SAndroid Build Coastguard Worker 
187*970e1046SAndroid Build Coastguard Worker       double[] componentASums = new double[clusterCount];
188*970e1046SAndroid Build Coastguard Worker       double[] componentBSums = new double[clusterCount];
189*970e1046SAndroid Build Coastguard Worker       double[] componentCSums = new double[clusterCount];
190*970e1046SAndroid Build Coastguard Worker       Arrays.fill(pixelCountSums, 0);
191*970e1046SAndroid Build Coastguard Worker       for (int i = 0; i < pointCount; i++) {
192*970e1046SAndroid Build Coastguard Worker         int clusterIndex = clusterIndices[i];
193*970e1046SAndroid Build Coastguard Worker         double[] point = points[i];
194*970e1046SAndroid Build Coastguard Worker         int count = counts[i];
195*970e1046SAndroid Build Coastguard Worker         pixelCountSums[clusterIndex] += count;
196*970e1046SAndroid Build Coastguard Worker         componentASums[clusterIndex] += (point[0] * count);
197*970e1046SAndroid Build Coastguard Worker         componentBSums[clusterIndex] += (point[1] * count);
198*970e1046SAndroid Build Coastguard Worker         componentCSums[clusterIndex] += (point[2] * count);
199*970e1046SAndroid Build Coastguard Worker       }
200*970e1046SAndroid Build Coastguard Worker 
201*970e1046SAndroid Build Coastguard Worker       for (int i = 0; i < clusterCount; i++) {
202*970e1046SAndroid Build Coastguard Worker         int count = pixelCountSums[i];
203*970e1046SAndroid Build Coastguard Worker         if (count == 0) {
204*970e1046SAndroid Build Coastguard Worker           clusters[i] = new double[] {0., 0., 0.};
205*970e1046SAndroid Build Coastguard Worker           continue;
206*970e1046SAndroid Build Coastguard Worker         }
207*970e1046SAndroid Build Coastguard Worker         double a = componentASums[i] / count;
208*970e1046SAndroid Build Coastguard Worker         double b = componentBSums[i] / count;
209*970e1046SAndroid Build Coastguard Worker         double c = componentCSums[i] / count;
210*970e1046SAndroid Build Coastguard Worker         clusters[i][0] = a;
211*970e1046SAndroid Build Coastguard Worker         clusters[i][1] = b;
212*970e1046SAndroid Build Coastguard Worker         clusters[i][2] = c;
213*970e1046SAndroid Build Coastguard Worker       }
214*970e1046SAndroid Build Coastguard Worker     }
215*970e1046SAndroid Build Coastguard Worker 
216*970e1046SAndroid Build Coastguard Worker     Map<Integer, Integer> argbToPopulation = new LinkedHashMap<>();
217*970e1046SAndroid Build Coastguard Worker     for (int i = 0; i < clusterCount; i++) {
218*970e1046SAndroid Build Coastguard Worker       int count = pixelCountSums[i];
219*970e1046SAndroid Build Coastguard Worker       if (count == 0) {
220*970e1046SAndroid Build Coastguard Worker         continue;
221*970e1046SAndroid Build Coastguard Worker       }
222*970e1046SAndroid Build Coastguard Worker 
223*970e1046SAndroid Build Coastguard Worker       int possibleNewCluster = pointProvider.toInt(clusters[i]);
224*970e1046SAndroid Build Coastguard Worker       if (argbToPopulation.containsKey(possibleNewCluster)) {
225*970e1046SAndroid Build Coastguard Worker         continue;
226*970e1046SAndroid Build Coastguard Worker       }
227*970e1046SAndroid Build Coastguard Worker 
228*970e1046SAndroid Build Coastguard Worker       argbToPopulation.put(possibleNewCluster, count);
229*970e1046SAndroid Build Coastguard Worker     }
230*970e1046SAndroid Build Coastguard Worker 
231*970e1046SAndroid Build Coastguard Worker     return argbToPopulation;
232*970e1046SAndroid Build Coastguard Worker   }
233*970e1046SAndroid Build Coastguard Worker }
234