blob: 0b501bfc04997188a3a5ab558d79e8b1e996e093 [file] [log] [blame]
// kmeans implements a generic k-means clustering algorithm.
//
// To use this code create types that implement Clusterable, Centroid, and
// also a function that implements CalculateCentroid. In many cases the same
// type can be used as both a Clusterable and a Centroid.
//
// See the unit tests for examples.
//
package kmeans
import "math"
// Clusterable defines the interface that an object must support to do k-means
// clustering on it.
//
// It is intentionally empty: any value can act as an observation. How two
// values are compared is defined entirely by the Centroid's Distance method.
type Clusterable interface{}

// Centroid is the interface that Centroids must support to do k-means clustering.
type Centroid interface {
	// AsClusterable converts this Centroid to a Clusterable, or returns nil if
	// the conversion isn't possible.
	AsClusterable() Clusterable

	// Distance returns the distance from the given Clusterable to this Centroid.
	Distance(c Clusterable) float64
}

// CalculateCentroid calculates a new centroid from a list of Clusterables.
type CalculateCentroid func([]Clusterable) Centroid
// closestCentroid returns the index of the centroid nearest to the given
// observation, along with that distance. If centroids is empty it returns
// (-1, math.MaxFloat64).
func closestCentroid(observation Clusterable, centroids []Centroid) (int, float64) {
	minIndex := -1
	minDist := math.MaxFloat64
	for i := range centroids {
		if d := centroids[i].Distance(observation); d < minDist {
			minDist = d
			minIndex = i
		}
	}
	return minIndex, minDist
}
// Do does a single iteration of Lloyd's algorithm, taking a slice of
// observations and a set of centroids along with a function to calculate new
// centroids for a cluster. It returns an updated slice of centroids, which
// may contain fewer entries than the input if a cluster ends up with no
// members (empty clusters are dropped rather than passed to f). The best way
// to call the function is:
//
//	centroids = Do(observations, centroids, f)
//
func Do(observations []Clusterable, centroids []Centroid, f CalculateCentroid) []Centroid {
	k := len(centroids)
	// members[i] holds the observations currently assigned to centroid i.
	// Bucketing each observation in a single pass avoids re-scanning all
	// observations once per cluster (O(n) instead of O(n*k)); output order
	// is unchanged: clusters in centroid order, members in observation order.
	members := make([][]Clusterable, k)
	for _, o := range observations {
		// closestCentroid returns -1 when centroids is empty; skip in
		// that case rather than panic on a negative index.
		if i, _ := closestCentroid(o, centroids); i >= 0 {
			members[i] = append(members[i], o)
		}
	}
	// Calculate new centroids based on the new cluster members.
	newCentroids := make([]Centroid, 0, k)
	for _, c := range members {
		if len(c) != 0 {
			newCentroids = append(newCentroids, f(c))
		}
	}
	return newCentroids
}
// GetClusters returns the observations categorized into the clusters they fit
// into, along with the total error (the sum over all observations of the
// distance to their closest centroid). The returned slice has one entry per
// centroid, in the same order as centroids; it is NOT sorted by cluster size.
// The very first element of each cluster is the centroid itself (when
// AsClusterable can convert it), the remaining members are the observations
// that are in the cluster.
func GetClusters(observations []Clusterable, centroids []Centroid) ([][]Clusterable, float64) {
	r := make([][]Clusterable, len(centroids))
	for i := range r {
		// The first entry is always the centroid for the cluster, when the
		// centroid can represent itself as a Clusterable.
		cl := centroids[i].AsClusterable()
		if cl != nil {
			r[i] = []Clusterable{cl}
		} else {
			r[i] = []Clusterable{}
		}
	}
	totalError := 0.0
	for _, o := range observations {
		index, clusterError := closestCentroid(o, centroids)
		totalError += clusterError
		r[index] = append(r[index], o)
	}
	return r, totalError
}
// KMeans runs the k-means clustering algorithm over a set of observations for
// a fixed number of iterations and returns the resulting centroids and
// clusters.
//
// Note that the k parameter is not referenced by this implementation; the
// number of clusters is governed by len(centroids).
//
// TODO(jcgregorio) Should just iterate until total error stops changing.
func KMeans(observations []Clusterable, centroids []Centroid, k, iters int, f CalculateCentroid) ([]Centroid, [][]Clusterable) {
	for iter := 0; iter < iters; iter++ {
		centroids = Do(observations, centroids, f)
	}
	clusters, _ := GetClusters(observations, centroids)
	return centroids, clusters
}
// TotalError calculates the total error between the centroids and the
// observations, i.e. the sum of each observation's distance to its closest
// centroid.
func TotalError(observations []Clusterable, centroids []Centroid) float64 {
	_, totalErr := GetClusters(observations, centroids)
	return totalErr
}