1*970e1046SAndroid Build Coastguard Worker /* 2*970e1046SAndroid Build Coastguard Worker * Copyright 2021 Google LLC 3*970e1046SAndroid Build Coastguard Worker * 4*970e1046SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*970e1046SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*970e1046SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*970e1046SAndroid Build Coastguard Worker * 8*970e1046SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*970e1046SAndroid Build Coastguard Worker * 10*970e1046SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*970e1046SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*970e1046SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*970e1046SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*970e1046SAndroid Build Coastguard Worker * limitations under the License. 15*970e1046SAndroid Build Coastguard Worker */ 16*970e1046SAndroid Build Coastguard Worker 17*970e1046SAndroid Build Coastguard Worker package com.google.ux.material.libmonet.quantize; 18*970e1046SAndroid Build Coastguard Worker 19*970e1046SAndroid Build Coastguard Worker import static java.lang.Math.min; 20*970e1046SAndroid Build Coastguard Worker 21*970e1046SAndroid Build Coastguard Worker import java.util.Arrays; 22*970e1046SAndroid Build Coastguard Worker import java.util.LinkedHashMap; 23*970e1046SAndroid Build Coastguard Worker import java.util.Map; 24*970e1046SAndroid Build Coastguard Worker import java.util.Random; 25*970e1046SAndroid Build Coastguard Worker 26*970e1046SAndroid Build Coastguard Worker /** 27*970e1046SAndroid Build Coastguard Worker * An image quantizer that improves on the speed of a standard K-Means algorithm by implementing 28*970e1046SAndroid Build Coastguard Worker * several optimizations, including deduping identical pixels and a triangle inequality rule that 29*970e1046SAndroid Build Coastguard Worker * reduces the number of comparisons needed to identify which cluster a point should be moved to. 30*970e1046SAndroid Build Coastguard Worker * 31*970e1046SAndroid Build Coastguard Worker * <p>Wsmeans stands for Weighted Square Means. 32*970e1046SAndroid Build Coastguard Worker * 33*970e1046SAndroid Build Coastguard Worker * <p>This algorithm was designed by M. Emre Celebi, and was found in their 2011 paper, Improving 34*970e1046SAndroid Build Coastguard Worker * the Performance of K-Means for Color Quantization. https://arxiv.org/abs/1101.0395 35*970e1046SAndroid Build Coastguard Worker */ 36*970e1046SAndroid Build Coastguard Worker public final class QuantizerWsmeans { QuantizerWsmeans()37*970e1046SAndroid Build Coastguard Worker private QuantizerWsmeans() {} 38*970e1046SAndroid Build Coastguard Worker 39*970e1046SAndroid Build Coastguard Worker private static final class Distance implements Comparable<Distance> { 40*970e1046SAndroid Build Coastguard Worker int index; 41*970e1046SAndroid Build Coastguard Worker double distance; 42*970e1046SAndroid Build Coastguard Worker Distance()43*970e1046SAndroid Build Coastguard Worker Distance() { 44*970e1046SAndroid Build Coastguard Worker this.index = -1; 45*970e1046SAndroid Build Coastguard Worker this.distance = -1; 46*970e1046SAndroid Build Coastguard Worker } 47*970e1046SAndroid Build Coastguard Worker 48*970e1046SAndroid Build Coastguard Worker @Override compareTo(Distance other)49*970e1046SAndroid Build Coastguard Worker public int compareTo(Distance other) { 50*970e1046SAndroid Build Coastguard Worker return ((Double) this.distance).compareTo(other.distance); 51*970e1046SAndroid Build Coastguard Worker } 52*970e1046SAndroid Build Coastguard Worker } 53*970e1046SAndroid Build Coastguard Worker 54*970e1046SAndroid Build Coastguard Worker private static final int MAX_ITERATIONS = 10; 55*970e1046SAndroid Build Coastguard Worker private static final double MIN_MOVEMENT_DISTANCE = 3.0; 56*970e1046SAndroid Build Coastguard Worker 57*970e1046SAndroid Build Coastguard Worker /** 58*970e1046SAndroid Build Coastguard Worker * Reduce the number of colors needed to represented the input, minimizing the difference between 59*970e1046SAndroid Build Coastguard Worker * the original image and the recolored image. 60*970e1046SAndroid Build Coastguard Worker * 61*970e1046SAndroid Build Coastguard Worker * @param inputPixels Colors in ARGB format. 62*970e1046SAndroid Build Coastguard Worker * @param startingClusters Defines the initial state of the quantizer. Passing an empty array is 63*970e1046SAndroid Build Coastguard Worker * fine, the implementation will create its own initial state that leads to reproducible 64*970e1046SAndroid Build Coastguard Worker * results for the same inputs. Passing an array that is the result of Wu quantization leads 65*970e1046SAndroid Build Coastguard Worker * to higher quality results. 66*970e1046SAndroid Build Coastguard Worker * @param maxColors The number of colors to divide the image into. A lower number of colors may be 67*970e1046SAndroid Build Coastguard Worker * returned. 68*970e1046SAndroid Build Coastguard Worker * @return Map with keys of colors in ARGB format, values of how many of the input pixels belong 69*970e1046SAndroid Build Coastguard Worker * to the color. 70*970e1046SAndroid Build Coastguard Worker */ quantize( int[] inputPixels, int[] startingClusters, int maxColors)71*970e1046SAndroid Build Coastguard Worker public static Map<Integer, Integer> quantize( 72*970e1046SAndroid Build Coastguard Worker int[] inputPixels, int[] startingClusters, int maxColors) { 73*970e1046SAndroid Build Coastguard Worker // Uses a seeded random number generator to ensure consistent results. 74*970e1046SAndroid Build Coastguard Worker Random random = new Random(0x42688); 75*970e1046SAndroid Build Coastguard Worker 76*970e1046SAndroid Build Coastguard Worker Map<Integer, Integer> pixelToCount = new LinkedHashMap<>(); 77*970e1046SAndroid Build Coastguard Worker double[][] points = new double[inputPixels.length][]; 78*970e1046SAndroid Build Coastguard Worker int[] pixels = new int[inputPixels.length]; 79*970e1046SAndroid Build Coastguard Worker PointProvider pointProvider = new PointProviderLab(); 80*970e1046SAndroid Build Coastguard Worker 81*970e1046SAndroid Build Coastguard Worker int pointCount = 0; 82*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < inputPixels.length; i++) { 83*970e1046SAndroid Build Coastguard Worker int inputPixel = inputPixels[i]; 84*970e1046SAndroid Build Coastguard Worker Integer pixelCount = pixelToCount.get(inputPixel); 85*970e1046SAndroid Build Coastguard Worker if (pixelCount == null) { 86*970e1046SAndroid Build Coastguard Worker points[pointCount] = pointProvider.fromInt(inputPixel); 87*970e1046SAndroid Build Coastguard Worker pixels[pointCount] = inputPixel; 88*970e1046SAndroid Build Coastguard Worker pointCount++; 89*970e1046SAndroid Build Coastguard Worker 90*970e1046SAndroid Build Coastguard Worker pixelToCount.put(inputPixel, 1); 91*970e1046SAndroid Build Coastguard Worker } else { 92*970e1046SAndroid Build Coastguard Worker pixelToCount.put(inputPixel, pixelCount + 1); 93*970e1046SAndroid Build Coastguard Worker } 94*970e1046SAndroid Build Coastguard Worker } 95*970e1046SAndroid Build Coastguard Worker 96*970e1046SAndroid Build Coastguard Worker int[] counts = new int[pointCount]; 97*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < pointCount; i++) { 98*970e1046SAndroid Build Coastguard Worker int pixel = pixels[i]; 99*970e1046SAndroid Build Coastguard Worker int count = pixelToCount.get(pixel); 100*970e1046SAndroid Build Coastguard Worker counts[i] = count; 101*970e1046SAndroid Build Coastguard Worker } 102*970e1046SAndroid Build Coastguard Worker 103*970e1046SAndroid Build Coastguard Worker int clusterCount = min(maxColors, pointCount); 104*970e1046SAndroid Build Coastguard Worker if (startingClusters.length != 0) { 105*970e1046SAndroid Build Coastguard Worker clusterCount = min(clusterCount, startingClusters.length); 106*970e1046SAndroid Build Coastguard Worker } 107*970e1046SAndroid Build Coastguard Worker 108*970e1046SAndroid Build Coastguard Worker double[][] clusters = new double[clusterCount][]; 109*970e1046SAndroid Build Coastguard Worker int clustersCreated = 0; 110*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < startingClusters.length; i++) { 111*970e1046SAndroid Build Coastguard Worker clusters[i] = pointProvider.fromInt(startingClusters[i]); 112*970e1046SAndroid Build Coastguard Worker clustersCreated++; 113*970e1046SAndroid Build Coastguard Worker } 114*970e1046SAndroid Build Coastguard Worker 115*970e1046SAndroid Build Coastguard Worker int additionalClustersNeeded = clusterCount - clustersCreated; 116*970e1046SAndroid Build Coastguard Worker if (additionalClustersNeeded > 0) { 117*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < additionalClustersNeeded; i++) {} 118*970e1046SAndroid Build Coastguard Worker } 119*970e1046SAndroid Build Coastguard Worker 120*970e1046SAndroid Build Coastguard Worker int[] clusterIndices = new int[pointCount]; 121*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < pointCount; i++) { 122*970e1046SAndroid Build Coastguard Worker clusterIndices[i] = random.nextInt(clusterCount); 123*970e1046SAndroid Build Coastguard Worker } 124*970e1046SAndroid Build Coastguard Worker 125*970e1046SAndroid Build Coastguard Worker int[][] indexMatrix = new int[clusterCount][]; 126*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < clusterCount; i++) { 127*970e1046SAndroid Build Coastguard Worker indexMatrix[i] = new int[clusterCount]; 128*970e1046SAndroid Build Coastguard Worker } 129*970e1046SAndroid Build Coastguard Worker 130*970e1046SAndroid Build Coastguard Worker Distance[][] distanceToIndexMatrix = new Distance[clusterCount][]; 131*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < clusterCount; i++) { 132*970e1046SAndroid Build Coastguard Worker distanceToIndexMatrix[i] = new Distance[clusterCount]; 133*970e1046SAndroid Build Coastguard Worker for (int j = 0; j < clusterCount; j++) { 134*970e1046SAndroid Build Coastguard Worker distanceToIndexMatrix[i][j] = new Distance(); 135*970e1046SAndroid Build Coastguard Worker } 136*970e1046SAndroid Build Coastguard Worker } 137*970e1046SAndroid Build Coastguard Worker 138*970e1046SAndroid Build Coastguard Worker int[] pixelCountSums = new int[clusterCount]; 139*970e1046SAndroid Build Coastguard Worker for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) { 140*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < clusterCount; i++) { 141*970e1046SAndroid Build Coastguard Worker for (int j = i + 1; j < clusterCount; j++) { 142*970e1046SAndroid Build Coastguard Worker double distance = pointProvider.distance(clusters[i], clusters[j]); 143*970e1046SAndroid Build Coastguard Worker distanceToIndexMatrix[j][i].distance = distance; 144*970e1046SAndroid Build Coastguard Worker distanceToIndexMatrix[j][i].index = i; 145*970e1046SAndroid Build Coastguard Worker distanceToIndexMatrix[i][j].distance = distance; 146*970e1046SAndroid Build Coastguard Worker distanceToIndexMatrix[i][j].index = j; 147*970e1046SAndroid Build Coastguard Worker } 148*970e1046SAndroid Build Coastguard Worker Arrays.sort(distanceToIndexMatrix[i]); 149*970e1046SAndroid Build Coastguard Worker for (int j = 0; j < clusterCount; j++) { 150*970e1046SAndroid Build Coastguard Worker indexMatrix[i][j] = distanceToIndexMatrix[i][j].index; 151*970e1046SAndroid Build Coastguard Worker } 152*970e1046SAndroid Build Coastguard Worker } 153*970e1046SAndroid Build Coastguard Worker 154*970e1046SAndroid Build Coastguard Worker int pointsMoved = 0; 155*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < pointCount; i++) { 156*970e1046SAndroid Build Coastguard Worker double[] point = points[i]; 157*970e1046SAndroid Build Coastguard Worker int previousClusterIndex = clusterIndices[i]; 158*970e1046SAndroid Build Coastguard Worker double[] previousCluster = clusters[previousClusterIndex]; 159*970e1046SAndroid Build Coastguard Worker double previousDistance = pointProvider.distance(point, previousCluster); 160*970e1046SAndroid Build Coastguard Worker 161*970e1046SAndroid Build Coastguard Worker double minimumDistance = previousDistance; 162*970e1046SAndroid Build Coastguard Worker int newClusterIndex = -1; 163*970e1046SAndroid Build Coastguard Worker for (int j = 0; j < clusterCount; j++) { 164*970e1046SAndroid Build Coastguard Worker if (distanceToIndexMatrix[previousClusterIndex][j].distance >= 4 * previousDistance) { 165*970e1046SAndroid Build Coastguard Worker continue; 166*970e1046SAndroid Build Coastguard Worker } 167*970e1046SAndroid Build Coastguard Worker double distance = pointProvider.distance(point, clusters[j]); 168*970e1046SAndroid Build Coastguard Worker if (distance < minimumDistance) { 169*970e1046SAndroid Build Coastguard Worker minimumDistance = distance; 170*970e1046SAndroid Build Coastguard Worker newClusterIndex = j; 171*970e1046SAndroid Build Coastguard Worker } 172*970e1046SAndroid Build Coastguard Worker } 173*970e1046SAndroid Build Coastguard Worker if (newClusterIndex != -1) { 174*970e1046SAndroid Build Coastguard Worker double distanceChange = 175*970e1046SAndroid Build Coastguard Worker Math.abs(Math.sqrt(minimumDistance) - Math.sqrt(previousDistance)); 176*970e1046SAndroid Build Coastguard Worker if (distanceChange > MIN_MOVEMENT_DISTANCE) { 177*970e1046SAndroid Build Coastguard Worker pointsMoved++; 178*970e1046SAndroid Build Coastguard Worker clusterIndices[i] = newClusterIndex; 179*970e1046SAndroid Build Coastguard Worker } 180*970e1046SAndroid Build Coastguard Worker } 181*970e1046SAndroid Build Coastguard Worker } 182*970e1046SAndroid Build Coastguard Worker 183*970e1046SAndroid Build Coastguard Worker if (pointsMoved == 0 && iteration != 0) { 184*970e1046SAndroid Build Coastguard Worker break; 185*970e1046SAndroid Build Coastguard Worker } 186*970e1046SAndroid Build Coastguard Worker 187*970e1046SAndroid Build Coastguard Worker double[] componentASums = new double[clusterCount]; 188*970e1046SAndroid Build Coastguard Worker double[] componentBSums = new double[clusterCount]; 189*970e1046SAndroid Build Coastguard Worker double[] componentCSums = new double[clusterCount]; 190*970e1046SAndroid Build Coastguard Worker Arrays.fill(pixelCountSums, 0); 191*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < pointCount; i++) { 192*970e1046SAndroid Build Coastguard Worker int clusterIndex = clusterIndices[i]; 193*970e1046SAndroid Build Coastguard Worker double[] point = points[i]; 194*970e1046SAndroid Build Coastguard Worker int count = counts[i]; 195*970e1046SAndroid Build Coastguard Worker pixelCountSums[clusterIndex] += count; 196*970e1046SAndroid Build Coastguard Worker componentASums[clusterIndex] += (point[0] * count); 197*970e1046SAndroid Build Coastguard Worker componentBSums[clusterIndex] += (point[1] * count); 198*970e1046SAndroid Build Coastguard Worker componentCSums[clusterIndex] += (point[2] * count); 199*970e1046SAndroid Build Coastguard Worker } 200*970e1046SAndroid Build Coastguard Worker 201*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < clusterCount; i++) { 202*970e1046SAndroid Build Coastguard Worker int count = pixelCountSums[i]; 203*970e1046SAndroid Build Coastguard Worker if (count == 0) { 204*970e1046SAndroid Build Coastguard Worker clusters[i] = new double[] {0., 0., 0.}; 205*970e1046SAndroid Build Coastguard Worker continue; 206*970e1046SAndroid Build Coastguard Worker } 207*970e1046SAndroid Build Coastguard Worker double a = componentASums[i] / count; 208*970e1046SAndroid Build Coastguard Worker double b = componentBSums[i] / count; 209*970e1046SAndroid Build Coastguard Worker double c = componentCSums[i] / count; 210*970e1046SAndroid Build Coastguard Worker clusters[i][0] = a; 211*970e1046SAndroid Build Coastguard Worker clusters[i][1] = b; 212*970e1046SAndroid Build Coastguard Worker clusters[i][2] = c; 213*970e1046SAndroid Build Coastguard Worker } 214*970e1046SAndroid Build Coastguard Worker } 215*970e1046SAndroid Build Coastguard Worker 216*970e1046SAndroid Build Coastguard Worker Map<Integer, Integer> argbToPopulation = new LinkedHashMap<>(); 217*970e1046SAndroid Build Coastguard Worker for (int i = 0; i < clusterCount; i++) { 218*970e1046SAndroid Build Coastguard Worker int count = pixelCountSums[i]; 219*970e1046SAndroid Build Coastguard Worker if (count == 0) { 220*970e1046SAndroid Build Coastguard Worker continue; 221*970e1046SAndroid Build Coastguard Worker } 222*970e1046SAndroid Build Coastguard Worker 223*970e1046SAndroid Build Coastguard Worker int possibleNewCluster = pointProvider.toInt(clusters[i]); 224*970e1046SAndroid Build Coastguard Worker if (argbToPopulation.containsKey(possibleNewCluster)) { 225*970e1046SAndroid Build Coastguard Worker continue; 226*970e1046SAndroid Build Coastguard Worker } 227*970e1046SAndroid Build Coastguard Worker 228*970e1046SAndroid Build Coastguard Worker argbToPopulation.put(possibleNewCluster, count); 229*970e1046SAndroid Build Coastguard Worker } 230*970e1046SAndroid Build Coastguard Worker 231*970e1046SAndroid Build Coastguard Worker return argbToPopulation; 232*970e1046SAndroid Build Coastguard Worker } 233*970e1046SAndroid Build Coastguard Worker } 234