diff --git a/README.md b/README.md index f8848fa..a20cd32 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,15 @@ to the cluster with the nearest mean, serving as a prototype of the cluster. ## Example ```go -import "github.com/muesli/kmeans" +import ( + "github.com/muesli/kmeans" + "github.com/muesli/clusters" +) // set up a random two-dimensional data set (float64 values between 0.0 and 1.0) -var d kmeans.Points +var d clusters.Observations for x := 0; x < 1024; x++ { - d = append(d, kmeans.Point{ + d = append(d, clusters.Coordinates{ rand.Float64(), rand.Float64(), }) @@ -35,8 +38,8 @@ km := kmeans.New() clusters, err := km.Partition(d, 16) for _, c := range clusters { - fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0]*255.0, c.Center[1]*255.0) - fmt.Printf("Matching data points: %+v\n\n", c.Points) + fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1]) + fmt.Printf("Matching data points: %+v\n\n", c.Observations) } ``` diff --git a/cluster.go b/cluster.go deleted file mode 100644 index a960936..0000000 --- a/cluster.go +++ /dev/null @@ -1,67 +0,0 @@ -package kmeans - -// A Cluster which data points gravitate around -type Cluster struct { - Center Point - Points Points -} - -// Clusters is a slice of clusters -type Clusters []Cluster - -// Nearest returns the index of the cluster nearest to point -func (c Clusters) Nearest(point Point) int { - var dist float64 - var ci int - - // Find the nearest cluster for this data point - for i, cluster := range c { - d := point.Distance(cluster.Center) - if dist == 0 || d < dist { - dist = d - ci = i - } - } - - return ci -} - -// recenter recenters a cluster -func (c *Cluster) recenter() { - center, err := c.Points.Mean() - if err != nil { - return - } - - c.Center = center -} - -// recenter recenters all clusters -func (c Clusters) recenter() { - for i := 0; i < len(c); i++ { - c[i].recenter() - } -} - -// reset clears all point assignments -func (c Clusters) reset() { - for i := 0; i < len(c); i++ { - c[i].Points = Points{} - } -} - -func (c *Cluster) pointsInDimension(n int) []float64 { - var v []float64 - for _, p := range c.Points { - v = append(v, p[n]) - } - return v -} - -func (c Clusters) centersInDimension(n int) []float64 { - var v []float64 - for _, cl := range c { - v = append(v, cl.Center[n]) - } - return v -} diff --git a/examples/colorpalette/main.go b/examples/colorpalette/main.go index bbba3bb..fc998a3 100644 --- a/examples/colorpalette/main.go +++ b/examples/colorpalette/main.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" + "github.com/muesli/clusters" "github.com/muesli/kmeans" colorful "github.com/lucasb-eyer/go-colorful" @@ -30,18 +31,39 @@ var ( ` ) +type Color struct { + color colorful.Color +} + +func (c Color) Coordinates() clusters.Coordinates { + l, a, b := c.color.Lab() + return clusters.Coordinates{ + l, + a, + b, + } +} + +func (c Color) Distance(pos clusters.Coordinates) float64 { + c2 := colorful.Lab(pos[0], pos[1], pos[2]) + return c.color.DistanceLab(c2) +} + func main() { // Create data points in the CIE L*a*b color space // l for lightness channel // a, b for color channels - var d kmeans.Points - for l := 30; l < 230; l += 16 { - for a := 0; a < 255; a += 16 { - for b := 0; b < 255; b += 16 { - d = append(d, kmeans.Point{ - float64(l) / 255.0, - float64(a) / 255.0, - float64(b) / 255.0, + var d clusters.Observations + for l := 0.2; l < 0.8; l += 0.05 { + for a := -1.0; a < 1.0; a += 0.1 { + for b := -1.0; b < 1.0; b += 0.1 { + c := colorful.Lab(l, a, b) + if !c.IsValid() { + continue + } + + d = append(d, Color{ + color: c, }) } } @@ -59,7 +81,7 @@ func main() { for i, c := range clusters { fmt.Printf("Cluster: %d %+v\n", i, c.Center) - col := colorful.Lab(c.Center[0], -0.9+(c.Center[1]*1.8), -0.9+(c.Center[2]*1.8)).Clamped() + col := colorful.Lab(c.Center[0], c.Center[1], c.Center[2]).Clamped() fmt.Println("Color as Hex:", col.Hex()) buffer.Write([]byte(fmt.Sprintf(cell, col.Hex()))) diff --git a/examples/simple/main.go b/examples/simple/main.go index bb5899a..e77f695 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -5,6 +5,7 @@ import ( "math/rand" "time" + "github.com/muesli/clusters" "github.com/muesli/kmeans" ) @@ -12,9 +13,9 @@ func main() { rand.Seed(time.Now().UnixNano()) // set up a random two-dimensional data set (float64 values between 0.0 and 1.0) - var d kmeans.Points + var d clusters.Observations for x := 0; x < 1024; x++ { - d = append(d, kmeans.Point{ + d = append(d, clusters.Coordinates{ rand.Float64(), rand.Float64(), }) @@ -27,6 +28,6 @@ func main() { for i, c := range clusters { fmt.Printf("Cluster: %d\n", i) - fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0]*255.0, c.Center[1]*255.0) + fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1]) } } diff --git a/kmeans.go b/kmeans.go index 9d2a1ef..11cd7bb 100644 --- a/kmeans.go +++ b/kmeans.go @@ -5,7 +5,8 @@ package kmeans import ( "fmt" "math/rand" - "time" + + "github.com/muesli/clusters" ) // Kmeans configuration/option struct @@ -39,39 +40,16 @@ func New() Kmeans { return m } -func randomizeClusters(k int, dataset Points) (Clusters, error) { - var c Clusters - if len(dataset) == 0 || len(dataset[0]) == 0 { - return c, fmt.Errorf("there must be at least one dimension in the data set") - } - if k == 0 { - return c, fmt.Errorf("k must be greater than 0") - } - - rand.Seed(time.Now().UnixNano()) - for i := 0; i < k; i++ { - var p Point - for j := 0; j < len(dataset[0]); j++ { - p = append(p, rand.Float64()) - } - - c = append(c, Cluster{ - Center: p, - }) - } - return c, nil -} - // Partition executes the k-means algorithm on the given dataset and // partitions it into k clusters -func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) { +func (m Kmeans) Partition(dataset clusters.Observations, k int) (clusters.Clusters, error) { if k > len(dataset) { - return Clusters{}, fmt.Errorf("the size of the data set must at least equal k") + return clusters.Clusters{}, fmt.Errorf("the size of the data set must at least equal k") } - clusters, err := randomizeClusters(k, dataset) + cc, err := clusters.New(k, dataset) if err != nil { - return Clusters{}, err + return cc, err } points := make([]int, len(dataset)) @@ -79,19 +57,19 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) { for i := 0; changes > 0; i++ { changes = 0 - clusters.reset() + cc.Reset() for p, point := range dataset { - ci := clusters.Nearest(point) - clusters[ci].Points = append(clusters[ci].Points, point) + ci := cc.Nearest(point) + cc[ci].Append(point) if points[p] != ci { points[p] = ci changes++ } } - for ci := 0; ci < len(clusters); ci++ { - if len(clusters[ci].Points) == 0 { + for ci := 0; ci < len(cc); ci++ { + if len(cc[ci].Observations) == 0 { // During the iterations, if any of the cluster centers has no // data points associated with it, assign a random data point // to it. @@ -101,20 +79,20 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) { // find a cluster with at least two data points, otherwise // we're just emptying one cluster to fill another ri = rand.Intn(len(dataset)) - if len(clusters[points[ri]].Points) > 1 { + if len(cc[points[ri]].Observations) > 1 { break } } - clusters[ci].Points = append(clusters[ci].Points, dataset[ri]) + cc[ci].Append(dataset[ri]) points[ri] = ci } } if changes > 0 { - clusters.recenter() + cc.Recenter() } if m.plotter != nil { - m.plotter.Plot(clusters, i) + m.plotter.Plot(cc, i) } if i == m.iterationThreshold || changes < int(float64(len(dataset))*m.deltaThreshold) { @@ -123,5 +101,5 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) { } } - return clusters, nil + return cc, nil } diff --git a/kmeans_test.go b/kmeans_test.go index cea6071..4b2ea16 100644 --- a/kmeans_test.go +++ b/kmeans_test.go @@ -3,6 +3,8 @@ package kmeans import ( "math/rand" "testing" + + "github.com/muesli/clusters" ) var RANDOM_SEED = int64(42) @@ -21,13 +23,14 @@ func TestNewErrors(t *testing.T) { func TestPartitioningError(t *testing.T) { km := New() - if _, err := km.Partition(Points{}, 1); err == nil { + d := clusters.Observations{} + if _, err := km.Partition(d, 1); err == nil { t.Errorf("Expected error partitioning with empty data set, got nil") return } - d := Points{ - Point{ + d = clusters.Observations{ + clusters.Coordinates{ 0.1, 0.1, }, @@ -44,10 +47,10 @@ func TestPartitioningError(t *testing.T) { } func TestDimensions(t *testing.T) { - var d Points + var d clusters.Observations for x := 0; x < 255; x += 32 { for y := 0; y < 255; y += 32 { - d = append(d, Point{ + d = append(d, clusters.Coordinates{ float64(x) / 255.0, float64(y) / 255.0, }) @@ -69,10 +72,10 @@ func TestDimensions(t *testing.T) { func benchmarkPartition(size, partitions int, b *testing.B) { rand.Seed(RANDOM_SEED) - var d Points + var d clusters.Observations for i := 0; i < size; i++ { - d = append(d, Point{ + d = append(d, clusters.Coordinates{ rand.Float64(), rand.Float64(), }) diff --git a/plotter.go b/plotter.go index 24d21bc..81000cf 100644 --- a/plotter.go +++ b/plotter.go @@ -5,13 +5,15 @@ import ( "fmt" "io/ioutil" + "github.com/muesli/clusters" + "github.com/wcharczuk/go-chart" "github.com/wcharczuk/go-chart/drawing" ) // The Plotter interface lets you implement your own plotters type Plotter interface { - Plot(clusters Clusters, iteration int) + Plot(cc clusters.Clusters, iteration int) } // SimplePlotter is the default standard plotter for 2-dimensional data sets @@ -33,19 +35,19 @@ var colors = []drawing.Color{ } // Plot draw a 2-dimensional data set into a PNG file named {iteration}.png -func (p SimplePlotter) Plot(clusters Clusters, iteration int) { +func (p SimplePlotter) Plot(cc clusters.Clusters, iteration int) { var series []chart.Series // draw data points - for i, c := range clusters { + for i, c := range cc { series = append(series, chart.ContinuousSeries{ Style: chart.Style{ Show: true, StrokeWidth: chart.Disabled, DotColor: colors[i%len(colors)], DotWidth: 8}, - XValues: c.pointsInDimension(0), - YValues: c.pointsInDimension(1), + XValues: c.PointsInDimension(0), + YValues: c.PointsInDimension(1), }) } @@ -57,8 +59,8 @@ func (p SimplePlotter) Plot(clusters Clusters, iteration int) { DotColor: drawing.ColorBlack, DotWidth: 16, }, - XValues: clusters.centersInDimension(0), - YValues: clusters.centersInDimension(1), + XValues: cc.CentersInDimension(0), + YValues: cc.CentersInDimension(1), }) graph := chart.Chart{ diff --git a/point.go b/point.go deleted file mode 100644 index 2ae0fcc..0000000 --- a/point.go +++ /dev/null @@ -1,42 +0,0 @@ -package kmeans - -import ( - "fmt" - "math" -) - -// Point is a data point (float64 between 0.0 and 1.0) in n dimensions -type Point []float64 - -// Points is a slice of points -type Points []Point - -// Distance returns the euclidean distance between two data points -func (p Point) Distance(p2 Point) float64 { - var r float64 - for i, v := range p { - r += math.Pow(v-p2[i], 2) - } - return r -} - -// Mean returns the mean point of p -func (p Points) Mean() (Point, error) { - var l = len(p) - if l == 0 { - return Point{}, fmt.Errorf("there is no mean for an empty set of points") - } - - c := make([]float64, len(p[0])) - for _, point := range p { - for j, v := range point { - c[j] += v - } - } - - var point Point - for _, v := range c { - point = append(point, v/float64(l)) - } - return point, nil -}