blob: 029ec46acd11d2f44436171af98df3b08369be46 [file] [log] [blame]
package kmeans
import (
"math"
"testing"
)
// myObservation is a point in the plane. It is the Clusterable
// implementation used by the tests in this file.
type myObservation struct {
	x, y float64 // coordinates of the point
}
// Distance returns the Euclidean distance between m and other.
//
// other must hold a myObservation; any other dynamic type makes the type
// assertion panic.
func (m myObservation) Distance(other Clusterable) float64 {
	p := other.(myObservation)
	dx := m.x - p.x
	dy := m.y - p.y
	return math.Sqrt(dx*dx + dy*dy)
}
// calculateCentroid implements CalculateCentroid by averaging the x and y
// coordinates of members independently.
//
// Every member must hold a myObservation, otherwise the type assertion
// panics. An empty members slice yields NaN coordinates, because both sums
// are divided by a zero length.
func calculateCentroid(members []Clusterable) Clusterable {
	var total myObservation
	for i := range members {
		p := members[i].(myObservation)
		total.x += p.x
		total.y += p.y
	}
	n := float64(len(members))
	return myObservation{x: total.x / n, y: total.y / n}
}
// near reports whether a and b differ by strictly less than 0.001.
func near(a, b float64) bool {
	const tolerance = 0.001
	delta := math.Abs(a - b)
	return delta < tolerance
}
// almostEqual reports a test error for each coordinate (x, then y) of a and b
// that near does not consider close enough.
//
// Both a and b must hold myObservation values; any other dynamic type makes
// the type assertion panic.
func almostEqual(t *testing.T, a, b Clusterable) {
	// Mark this as a helper so failures are reported at the caller's line.
	t.Helper()
	got, want := a.(myObservation), b.(myObservation)
	if !near(got.x, want.x) {
		t.Errorf("Not near enough on the x: Got %f Want %f", got.x, want.x)
	}
	if !near(got.y, want.y) {
		t.Errorf("Not near enough on the y: Got %f Want %f", got.y, want.y)
	}
}
// TestBasicIteration runs a single Do pass over three observations with two
// starting centroids and checks that the first centroid stays at (0, 0) while
// the second ends up at (3, 0.5).
func TestBasicIteration(t *testing.T) {
	points := []Clusterable{
		myObservation{x: 0.0, y: 0.0},
		myObservation{x: 3.0, y: 0.0},
		myObservation{x: 3.0, y: 1.0},
	}
	start := []Clusterable{
		myObservation{x: 0.0, y: 0.0},
		myObservation{x: 3.0, y: 0.0},
	}

	result := Do(points, start, calculateCentroid)

	expected := []myObservation{
		{x: 0.0, y: 0.0},
		{x: 3.0, y: 0.5},
	}
	for i, want := range expected {
		almostEqual(t, result[i], want)
	}
}
// TestEmptyCentroids checks that Do returns no centroids when it is given
// observations but an empty slice of starting centroids.
func TestEmptyCentroids(t *testing.T) {
	points := []Clusterable{
		myObservation{x: 0.0, y: 0.0},
		myObservation{x: 3.0, y: 0.0},
		myObservation{x: 3.0, y: 1.0},
	}

	result := Do(points, []Clusterable{}, calculateCentroid)

	const want = 0
	if got := len(result); got != want {
		t.Errorf("Wrong length of centroids returned: Got %d, Want %d", got, want)
	}
}
func TestEmptyEverything(t *testing.T) {
observations := []Clusterable{}
centroids := []Clusterable{}
centroids = Do(observations, centroids, calculateCentroid)
if got, want := len(centroids), 0; got != want {
t.Errorf("Wrong length of centroids returned: Got %d, Want %d", got, want)
}
}
// TestLosingCentroids checks that Do returns only two centroids when it is
// given three starting centroids but just two observations.
func TestLosingCentroids(t *testing.T) {
	points := []Clusterable{
		myObservation{x: 0.0, y: 0.0},
		myObservation{x: 3.0, y: 0.0},
	}
	start := []Clusterable{
		myObservation{x: 0.0, y: 0.0},
		myObservation{x: 3.0, y: 0.0},
		myObservation{x: 3.0, y: 1.0},
	}

	result := Do(points, start, calculateCentroid)

	const want = 2
	if got := len(result); got != want {
		t.Errorf("Wrong length of centroids returned: Got %d, Want %d", got, want)
	}
}
// TestFullKmeans runs three Do passes over six observations starting from
// three centroids, then checks the number of centroids and the length of each
// cluster returned by GetClusters.
func TestFullKmeans(t *testing.T) {
	observations := []Clusterable{
		myObservation{0.0, 0.0},
		myObservation{3.0, 0.0},
		myObservation{3.0, 0.5},
		myObservation{6.0, 6.0},
		myObservation{6.0, 6.1},
		myObservation{6.0, 6.2},
	}
	centroids := []Clusterable{
		myObservation{0.0, 0.0},
		myObservation{3.0, 0.0},
		myObservation{6.0, 6.0},
	}

	const iterations = 3
	for i := 0; i < iterations; i++ {
		centroids = Do(observations, centroids, calculateCentroid)
	}

	// The second value returned by GetClusters is not examined by this test.
	clusters, _ := GetClusters(observations, centroids)
	if got, want := len(centroids), 3; got != want {
		t.Errorf("Wrong length of centroids: Got %d, Want %d", got, want)
	}

	// Expected cluster lengths, in the order GetClusters returns them.
	// NOTE(review): each is written as members+1; the extra element is
	// presumably the centroid stored alongside its members — confirm against
	// GetClusters.
	wantLens := []int{3 + 1, 2 + 1, 1 + 1}

	// Fail cleanly instead of panicking on an out-of-range index below.
	if len(clusters) < len(wantLens) {
		t.Fatalf("Wrong number of clusters: Got %d, Want at least %d", len(clusters), len(wantLens))
	}
	for i, want := range wantLens {
		if got := len(clusters[i]); got != want {
			t.Errorf("Wrong length of clusters[%d]: Got %d, Want %d", i, got, want)
		}
	}
}