/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License */ package com.android.internal.ml.clustering; import android.annotation.NonNull; import android.util.Log; import com.android.internal.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; /** * Simple K-Means implementation */ public class KMeans { private static final boolean DEBUG = false; private static final String TAG = "KMeans"; private final Random mRandomState; private final int mMaxIterations; private float mSqConvergenceEpsilon; public KMeans() { this(new Random()); } public KMeans(Random random) { this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */); } public KMeans(Random random, int maxIterations, float convergenceEpsilon) { mRandomState = random; mMaxIterations = maxIterations; mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon; } /** * Runs k-means on the input data (X) trying to find k means. * * K-Means is known for getting stuck into local optima, so you might * want to run it multiple time and argmax on {@link KMeans#score(List)} * * @param k The number of points to return. * @param inputData Input data. * @return An array of k Means, each representing a centroid and data points that belong to it. */ public List predict(final int k, final float[][] inputData) { checkDataSetSanity(inputData); int dimension = inputData[0].length; final ArrayList means = new ArrayList<>(); for (int i = 0; i < k; i++) { Mean m = new Mean(dimension); for (int j = 0; j < dimension; j++) { m.mCentroid[j] = mRandomState.nextFloat(); } means.add(m); } // Iterate until we converge or run out of iterations boolean converged = false; for (int i = 0; i < mMaxIterations; i++) { converged = step(means, inputData); if (converged) { if (DEBUG) Log.d(TAG, "Converged at iteration: " + i); break; } } if (!converged && DEBUG) Log.d(TAG, "Did not converge"); return means; } /** * Score calculates the inertia between means. * This can be considered as an E step of an EM algorithm. * * @param means Means to use when calculating score. * @return The score */ public static double score(@NonNull List means) { double score = 0; final int meansSize = means.size(); for (int i = 0; i < meansSize; i++) { Mean mean = means.get(i); for (int j = 0; j < meansSize; j++) { Mean compareTo = means.get(j); if (mean == compareTo) { continue; } double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid)); score += distance; } } return score; } @VisibleForTesting public void checkDataSetSanity(float[][] inputData) { if (inputData == null) { throw new IllegalArgumentException("Data set is null."); } else if (inputData.length == 0) { throw new IllegalArgumentException("Data set is empty."); } else if (inputData[0] == null) { throw new IllegalArgumentException("Bad data set format."); } final int dimension = inputData[0].length; final int length = inputData.length; for (int i = 1; i < length; i++) { if (inputData[i] == null || inputData[i].length != dimension) { throw new IllegalArgumentException("Bad data set format."); } } } /** * K-Means iteration. * * @param means Current means * @param inputData Input data * @return True if data set converged */ private boolean step(final ArrayList means, final float[][] inputData) { // Clean up the previous state because we need to compute // which point belongs to each mean again. for (int i = means.size() - 1; i >= 0; i--) { final Mean mean = means.get(i); mean.mClosestItems.clear(); } for (int i = inputData.length - 1; i >= 0; i--) { final float[] current = inputData[i]; final Mean nearest = nearestMean(current, means); nearest.mClosestItems.add(current); } boolean converged = true; // Move each mean towards the nearest data set points for (int i = means.size() - 1; i >= 0; i--) { final Mean mean = means.get(i); if (mean.mClosestItems.size() == 0) { continue; } // Compute the new mean centroid: // 1. Sum all all points // 2. Average them final float[] oldCentroid = mean.mCentroid; mean.mCentroid = new float[oldCentroid.length]; for (int j = 0; j < mean.mClosestItems.size(); j++) { // Update each centroid component for (int p = 0; p < mean.mCentroid.length; p++) { mean.mCentroid[p] += mean.mClosestItems.get(j)[p]; } } for (int j = 0; j < mean.mCentroid.length; j++) { mean.mCentroid[j] /= mean.mClosestItems.size(); } // We converged if the centroid didn't move for any of the means. if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) { converged = false; } } return converged; } @VisibleForTesting public static Mean nearestMean(float[] point, List means) { Mean nearest = null; float nearestDistance = Float.MAX_VALUE; final int meanCount = means.size(); for (int i = 0; i < meanCount; i++) { Mean next = means.get(i); // We don't need the sqrt when comparing distances in euclidean space // because they exist on both sides of the equation and cancel each other out. float nextDistance = sqDistance(point, next.mCentroid); if (nextDistance < nearestDistance) { nearest = next; nearestDistance = nextDistance; } } return nearest; } @VisibleForTesting public static float sqDistance(float[] a, float[] b) { float dist = 0; final int length = a.length; for (int i = 0; i < length; i++) { dist += (a[i] - b[i]) * (a[i] - b[i]); } return dist; } /** * Definition of a mean, contains a centroid and points on its cluster. */ public static class Mean { float[] mCentroid; final ArrayList mClosestItems = new ArrayList<>(); public Mean(int dimension) { mCentroid = new float[dimension]; } public Mean(float ...centroid) { mCentroid = centroid; } public float[] getCentroid() { return mCentroid; } public List getItems() { return mClosestItems; } @Override public String toString() { return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: " + mClosestItems.size() + ")"; } } }