Moved clustering and observation structs to github.com/muesli/clusters

This commit is contained in:
Christian Muehlhaeuser
2018-06-02 14:06:07 +02:00
parent 0ad7a62e65
commit c58ae78401
8 changed files with 78 additions and 178 deletions

View File

@@ -19,12 +19,15 @@ to the cluster with the nearest mean, serving as a prototype of the cluster.
## Example ## Example
```go ```go
import "github.com/muesli/kmeans" import (
"github.com/muesli/kmeans"
"github.com/muesli/clusters"
)
// set up a random two-dimensional data set (float64 values between 0.0 and 1.0) // set up a random two-dimensional data set (float64 values between 0.0 and 1.0)
var d kmeans.Points var d clusters.Observations
for x := 0; x < 1024; x++ { for x := 0; x < 1024; x++ {
d = append(d, kmeans.Point{ d = append(d, clusters.Coordinates{
rand.Float64(), rand.Float64(),
rand.Float64(), rand.Float64(),
}) })
@@ -35,8 +38,8 @@ km := kmeans.New()
clusters, err := km.Partition(d, 16) clusters, err := km.Partition(d, 16)
for _, c := range clusters { for _, c := range clusters {
fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0]*255.0, c.Center[1]*255.0) fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1])
fmt.Printf("Matching data points: %+v\n\n", c.Points) fmt.Printf("Matching data points: %+v\n\n", c.Observations)
} }
``` ```

View File

@@ -1,67 +0,0 @@
package kmeans
// A Cluster which data points gravitate around
type Cluster struct {
Center Point
Points Points
}
// Clusters is a slice of clusters
type Clusters []Cluster
// Nearest returns the index of the cluster nearest to point
func (c Clusters) Nearest(point Point) int {
var dist float64
var ci int
// Find the nearest cluster for this data point
for i, cluster := range c {
d := point.Distance(cluster.Center)
if dist == 0 || d < dist {
dist = d
ci = i
}
}
return ci
}
// recenter recenters a cluster
func (c *Cluster) recenter() {
center, err := c.Points.Mean()
if err != nil {
return
}
c.Center = center
}
// recenter recenters all clusters
func (c Clusters) recenter() {
for i := 0; i < len(c); i++ {
c[i].recenter()
}
}
// reset clears all point assignments
func (c Clusters) reset() {
for i := 0; i < len(c); i++ {
c[i].Points = Points{}
}
}
func (c *Cluster) pointsInDimension(n int) []float64 {
var v []float64
for _, p := range c.Points {
v = append(v, p[n])
}
return v
}
func (c Clusters) centersInDimension(n int) []float64 {
var v []float64
for _, cl := range c {
v = append(v, cl.Center[n])
}
return v
}

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"github.com/muesli/clusters"
"github.com/muesli/kmeans" "github.com/muesli/kmeans"
colorful "github.com/lucasb-eyer/go-colorful" colorful "github.com/lucasb-eyer/go-colorful"
@@ -30,18 +31,39 @@ var (
` `
) )
type Color struct {
color colorful.Color
}
func (c Color) Coordinates() clusters.Coordinates {
l, a, b := c.color.Lab()
return clusters.Coordinates{
l,
a,
b,
}
}
func (c Color) Distance(pos clusters.Coordinates) float64 {
c2 := colorful.Lab(pos[0], pos[1], pos[2])
return c.color.DistanceLab(c2)
}
func main() { func main() {
// Create data points in the CIE L*a*b color space // Create data points in the CIE L*a*b color space
// l for lightness channel // l for lightness channel
// a, b for color channels // a, b for color channels
var d kmeans.Points var d clusters.Observations
for l := 30; l < 230; l += 16 { for l := 0.2; l < 0.8; l += 0.05 {
for a := 0; a < 255; a += 16 { for a := -1.0; a < 1.0; a += 0.1 {
for b := 0; b < 255; b += 16 { for b := -1.0; b < 1.0; b += 0.1 {
d = append(d, kmeans.Point{ c := colorful.Lab(l, a, b)
float64(l) / 255.0, if !c.IsValid() {
float64(a) / 255.0, continue
float64(b) / 255.0, }
d = append(d, Color{
color: c,
}) })
} }
} }
@@ -59,7 +81,7 @@ func main() {
for i, c := range clusters { for i, c := range clusters {
fmt.Printf("Cluster: %d %+v\n", i, c.Center) fmt.Printf("Cluster: %d %+v\n", i, c.Center)
col := colorful.Lab(c.Center[0], -0.9+(c.Center[1]*1.8), -0.9+(c.Center[2]*1.8)).Clamped() col := colorful.Lab(c.Center[0], c.Center[1], c.Center[2]).Clamped()
fmt.Println("Color as Hex:", col.Hex()) fmt.Println("Color as Hex:", col.Hex())
buffer.Write([]byte(fmt.Sprintf(cell, col.Hex()))) buffer.Write([]byte(fmt.Sprintf(cell, col.Hex())))

View File

@@ -5,6 +5,7 @@ import (
"math/rand" "math/rand"
"time" "time"
"github.com/muesli/clusters"
"github.com/muesli/kmeans" "github.com/muesli/kmeans"
) )
@@ -12,9 +13,9 @@ func main() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
// set up a random two-dimensional data set (float64 values between 0.0 and 1.0) // set up a random two-dimensional data set (float64 values between 0.0 and 1.0)
var d kmeans.Points var d clusters.Observations
for x := 0; x < 1024; x++ { for x := 0; x < 1024; x++ {
d = append(d, kmeans.Point{ d = append(d, clusters.Coordinates{
rand.Float64(), rand.Float64(),
rand.Float64(), rand.Float64(),
}) })
@@ -27,6 +28,6 @@ func main() {
for i, c := range clusters { for i, c := range clusters {
fmt.Printf("Cluster: %d\n", i) fmt.Printf("Cluster: %d\n", i)
fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0]*255.0, c.Center[1]*255.0) fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1])
} }
} }

View File

@@ -5,7 +5,8 @@ package kmeans
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"time"
"github.com/muesli/clusters"
) )
// Kmeans configuration/option struct // Kmeans configuration/option struct
@@ -39,39 +40,16 @@ func New() Kmeans {
return m return m
} }
func randomizeClusters(k int, dataset Points) (Clusters, error) {
var c Clusters
if len(dataset) == 0 || len(dataset[0]) == 0 {
return c, fmt.Errorf("there must be at least one dimension in the data set")
}
if k == 0 {
return c, fmt.Errorf("k must be greater than 0")
}
rand.Seed(time.Now().UnixNano())
for i := 0; i < k; i++ {
var p Point
for j := 0; j < len(dataset[0]); j++ {
p = append(p, rand.Float64())
}
c = append(c, Cluster{
Center: p,
})
}
return c, nil
}
// Partition executes the k-means algorithm on the given dataset and // Partition executes the k-means algorithm on the given dataset and
// partitions it into k clusters // partitions it into k clusters
func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) { func (m Kmeans) Partition(dataset clusters.Observations, k int) (clusters.Clusters, error) {
if k > len(dataset) { if k > len(dataset) {
return Clusters{}, fmt.Errorf("the size of the data set must at least equal k") return clusters.Clusters{}, fmt.Errorf("the size of the data set must at least equal k")
} }
clusters, err := randomizeClusters(k, dataset) cc, err := clusters.New(k, dataset)
if err != nil { if err != nil {
return Clusters{}, err return cc, err
} }
points := make([]int, len(dataset)) points := make([]int, len(dataset))
@@ -79,19 +57,19 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
for i := 0; changes > 0; i++ { for i := 0; changes > 0; i++ {
changes = 0 changes = 0
clusters.reset() cc.Reset()
for p, point := range dataset { for p, point := range dataset {
ci := clusters.Nearest(point) ci := cc.Nearest(point)
clusters[ci].Points = append(clusters[ci].Points, point) cc[ci].Append(point)
if points[p] != ci { if points[p] != ci {
points[p] = ci points[p] = ci
changes++ changes++
} }
} }
for ci := 0; ci < len(clusters); ci++ { for ci := 0; ci < len(cc); ci++ {
if len(clusters[ci].Points) == 0 { if len(cc[ci].Observations) == 0 {
// During the iterations, if any of the cluster centers has no // During the iterations, if any of the cluster centers has no
// data points associated with it, assign a random data point // data points associated with it, assign a random data point
// to it. // to it.
@@ -101,20 +79,20 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
// find a cluster with at least two data points, otherwise // find a cluster with at least two data points, otherwise
// we're just emptying one cluster to fill another // we're just emptying one cluster to fill another
ri = rand.Intn(len(dataset)) ri = rand.Intn(len(dataset))
if len(clusters[points[ri]].Points) > 1 { if len(cc[points[ri]].Observations) > 1 {
break break
} }
} }
clusters[ci].Points = append(clusters[ci].Points, dataset[ri]) cc[ci].Append(dataset[ri])
points[ri] = ci points[ri] = ci
} }
} }
if changes > 0 { if changes > 0 {
clusters.recenter() cc.Recenter()
} }
if m.plotter != nil { if m.plotter != nil {
m.plotter.Plot(clusters, i) m.plotter.Plot(cc, i)
} }
if i == m.iterationThreshold || if i == m.iterationThreshold ||
changes < int(float64(len(dataset))*m.deltaThreshold) { changes < int(float64(len(dataset))*m.deltaThreshold) {
@@ -123,5 +101,5 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
} }
} }
return clusters, nil return cc, nil
} }

View File

@@ -3,6 +3,8 @@ package kmeans
import ( import (
"math/rand" "math/rand"
"testing" "testing"
"github.com/muesli/clusters"
) )
var RANDOM_SEED = int64(42) var RANDOM_SEED = int64(42)
@@ -21,13 +23,14 @@ func TestNewErrors(t *testing.T) {
func TestPartitioningError(t *testing.T) { func TestPartitioningError(t *testing.T) {
km := New() km := New()
if _, err := km.Partition(Points{}, 1); err == nil { d := clusters.Observations{}
if _, err := km.Partition(d, 1); err == nil {
t.Errorf("Expected error partitioning with empty data set, got nil") t.Errorf("Expected error partitioning with empty data set, got nil")
return return
} }
d := Points{ d = clusters.Observations{
Point{ clusters.Coordinates{
0.1, 0.1,
0.1, 0.1,
}, },
@@ -44,10 +47,10 @@ func TestPartitioningError(t *testing.T) {
} }
func TestDimensions(t *testing.T) { func TestDimensions(t *testing.T) {
var d Points var d clusters.Observations
for x := 0; x < 255; x += 32 { for x := 0; x < 255; x += 32 {
for y := 0; y < 255; y += 32 { for y := 0; y < 255; y += 32 {
d = append(d, Point{ d = append(d, clusters.Coordinates{
float64(x) / 255.0, float64(x) / 255.0,
float64(y) / 255.0, float64(y) / 255.0,
}) })
@@ -69,10 +72,10 @@ func TestDimensions(t *testing.T) {
func benchmarkPartition(size, partitions int, b *testing.B) { func benchmarkPartition(size, partitions int, b *testing.B) {
rand.Seed(RANDOM_SEED) rand.Seed(RANDOM_SEED)
var d Points var d clusters.Observations
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
d = append(d, Point{ d = append(d, clusters.Coordinates{
rand.Float64(), rand.Float64(),
rand.Float64(), rand.Float64(),
}) })

View File

@@ -5,13 +5,15 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"github.com/muesli/clusters"
"github.com/wcharczuk/go-chart" "github.com/wcharczuk/go-chart"
"github.com/wcharczuk/go-chart/drawing" "github.com/wcharczuk/go-chart/drawing"
) )
// The Plotter interface lets you implement your own plotters // The Plotter interface lets you implement your own plotters
type Plotter interface { type Plotter interface {
Plot(clusters Clusters, iteration int) Plot(cc clusters.Clusters, iteration int)
} }
// SimplePlotter is the default standard plotter for 2-dimensional data sets // SimplePlotter is the default standard plotter for 2-dimensional data sets
@@ -33,19 +35,19 @@ var colors = []drawing.Color{
} }
// Plot draw a 2-dimensional data set into a PNG file named {iteration}.png // Plot draw a 2-dimensional data set into a PNG file named {iteration}.png
func (p SimplePlotter) Plot(clusters Clusters, iteration int) { func (p SimplePlotter) Plot(cc clusters.Clusters, iteration int) {
var series []chart.Series var series []chart.Series
// draw data points // draw data points
for i, c := range clusters { for i, c := range cc {
series = append(series, chart.ContinuousSeries{ series = append(series, chart.ContinuousSeries{
Style: chart.Style{ Style: chart.Style{
Show: true, Show: true,
StrokeWidth: chart.Disabled, StrokeWidth: chart.Disabled,
DotColor: colors[i%len(colors)], DotColor: colors[i%len(colors)],
DotWidth: 8}, DotWidth: 8},
XValues: c.pointsInDimension(0), XValues: c.PointsInDimension(0),
YValues: c.pointsInDimension(1), YValues: c.PointsInDimension(1),
}) })
} }
@@ -57,8 +59,8 @@ func (p SimplePlotter) Plot(clusters Clusters, iteration int) {
DotColor: drawing.ColorBlack, DotColor: drawing.ColorBlack,
DotWidth: 16, DotWidth: 16,
}, },
XValues: clusters.centersInDimension(0), XValues: cc.CentersInDimension(0),
YValues: clusters.centersInDimension(1), YValues: cc.CentersInDimension(1),
}) })
graph := chart.Chart{ graph := chart.Chart{

View File

@@ -1,42 +0,0 @@
package kmeans
import (
"fmt"
"math"
)
// Point is a data point (float64 between 0.0 and 1.0) in n dimensions
type Point []float64
// Points is a slice of points
type Points []Point
// Distance returns the euclidean distance between two data points
func (p Point) Distance(p2 Point) float64 {
var r float64
for i, v := range p {
r += math.Pow(v-p2[i], 2)
}
return r
}
// Mean returns the mean point of p
func (p Points) Mean() (Point, error) {
var l = len(p)
if l == 0 {
return Point{}, fmt.Errorf("there is no mean for an empty set of points")
}
c := make([]float64, len(p[0]))
for _, point := range p {
for j, v := range point {
c[j] += v
}
}
var point Point
for _, v := range c {
point = append(point, v/float64(l))
}
return point, nil
}