mirror of
https://github.com/muesli/kmeans.git
synced 2025-10-11 10:30:14 +08:00
Moved clustering and observation structs to github.com/muesli/clusters
This commit is contained in:
13
README.md
13
README.md
@@ -19,12 +19,15 @@ to the cluster with the nearest mean, serving as a prototype of the cluster.
|
|||||||
## Example
|
## Example
|
||||||
|
|
||||||
```go
|
```go
|
||||||
import "github.com/muesli/kmeans"
|
import (
|
||||||
|
"github.com/muesli/kmeans"
|
||||||
|
"github.com/muesli/clusters"
|
||||||
|
)
|
||||||
|
|
||||||
// set up a random two-dimensional data set (float64 values between 0.0 and 1.0)
|
// set up a random two-dimensional data set (float64 values between 0.0 and 1.0)
|
||||||
var d kmeans.Points
|
var d clusters.Observations
|
||||||
for x := 0; x < 1024; x++ {
|
for x := 0; x < 1024; x++ {
|
||||||
d = append(d, kmeans.Point{
|
d = append(d, clusters.Coordinates{
|
||||||
rand.Float64(),
|
rand.Float64(),
|
||||||
rand.Float64(),
|
rand.Float64(),
|
||||||
})
|
})
|
||||||
@@ -35,8 +38,8 @@ km := kmeans.New()
|
|||||||
clusters, err := km.Partition(d, 16)
|
clusters, err := km.Partition(d, 16)
|
||||||
|
|
||||||
for _, c := range clusters {
|
for _, c := range clusters {
|
||||||
fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0]*255.0, c.Center[1]*255.0)
|
fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1])
|
||||||
fmt.Printf("Matching data points: %+v\n\n", c.Points)
|
fmt.Printf("Matching data points: %+v\n\n", c.Observations)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
67
cluster.go
67
cluster.go
@@ -1,67 +0,0 @@
|
|||||||
package kmeans
|
|
||||||
|
|
||||||
// A Cluster which data points gravitate around
|
|
||||||
type Cluster struct {
|
|
||||||
Center Point
|
|
||||||
Points Points
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clusters is a slice of clusters
|
|
||||||
type Clusters []Cluster
|
|
||||||
|
|
||||||
// Nearest returns the index of the cluster nearest to point
|
|
||||||
func (c Clusters) Nearest(point Point) int {
|
|
||||||
var dist float64
|
|
||||||
var ci int
|
|
||||||
|
|
||||||
// Find the nearest cluster for this data point
|
|
||||||
for i, cluster := range c {
|
|
||||||
d := point.Distance(cluster.Center)
|
|
||||||
if dist == 0 || d < dist {
|
|
||||||
dist = d
|
|
||||||
ci = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ci
|
|
||||||
}
|
|
||||||
|
|
||||||
// recenter recenters a cluster
|
|
||||||
func (c *Cluster) recenter() {
|
|
||||||
center, err := c.Points.Mean()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Center = center
|
|
||||||
}
|
|
||||||
|
|
||||||
// recenter recenters all clusters
|
|
||||||
func (c Clusters) recenter() {
|
|
||||||
for i := 0; i < len(c); i++ {
|
|
||||||
c[i].recenter()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// reset clears all point assignments
|
|
||||||
func (c Clusters) reset() {
|
|
||||||
for i := 0; i < len(c); i++ {
|
|
||||||
c[i].Points = Points{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cluster) pointsInDimension(n int) []float64 {
|
|
||||||
var v []float64
|
|
||||||
for _, p := range c.Points {
|
|
||||||
v = append(v, p[n])
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Clusters) centersInDimension(n int) []float64 {
|
|
||||||
var v []float64
|
|
||||||
for _, cl := range c {
|
|
||||||
v = append(v, cl.Center[n])
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/muesli/clusters"
|
||||||
"github.com/muesli/kmeans"
|
"github.com/muesli/kmeans"
|
||||||
|
|
||||||
colorful "github.com/lucasb-eyer/go-colorful"
|
colorful "github.com/lucasb-eyer/go-colorful"
|
||||||
@@ -30,18 +31,39 @@ var (
|
|||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Color struct {
|
||||||
|
color colorful.Color
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Color) Coordinates() clusters.Coordinates {
|
||||||
|
l, a, b := c.color.Lab()
|
||||||
|
return clusters.Coordinates{
|
||||||
|
l,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Color) Distance(pos clusters.Coordinates) float64 {
|
||||||
|
c2 := colorful.Lab(pos[0], pos[1], pos[2])
|
||||||
|
return c.color.DistanceLab(c2)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Create data points in the CIE L*a*b color space
|
// Create data points in the CIE L*a*b color space
|
||||||
// l for lightness channel
|
// l for lightness channel
|
||||||
// a, b for color channels
|
// a, b for color channels
|
||||||
var d kmeans.Points
|
var d clusters.Observations
|
||||||
for l := 30; l < 230; l += 16 {
|
for l := 0.2; l < 0.8; l += 0.05 {
|
||||||
for a := 0; a < 255; a += 16 {
|
for a := -1.0; a < 1.0; a += 0.1 {
|
||||||
for b := 0; b < 255; b += 16 {
|
for b := -1.0; b < 1.0; b += 0.1 {
|
||||||
d = append(d, kmeans.Point{
|
c := colorful.Lab(l, a, b)
|
||||||
float64(l) / 255.0,
|
if !c.IsValid() {
|
||||||
float64(a) / 255.0,
|
continue
|
||||||
float64(b) / 255.0,
|
}
|
||||||
|
|
||||||
|
d = append(d, Color{
|
||||||
|
color: c,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -59,7 +81,7 @@ func main() {
|
|||||||
|
|
||||||
for i, c := range clusters {
|
for i, c := range clusters {
|
||||||
fmt.Printf("Cluster: %d %+v\n", i, c.Center)
|
fmt.Printf("Cluster: %d %+v\n", i, c.Center)
|
||||||
col := colorful.Lab(c.Center[0], -0.9+(c.Center[1]*1.8), -0.9+(c.Center[2]*1.8)).Clamped()
|
col := colorful.Lab(c.Center[0], c.Center[1], c.Center[2]).Clamped()
|
||||||
fmt.Println("Color as Hex:", col.Hex())
|
fmt.Println("Color as Hex:", col.Hex())
|
||||||
|
|
||||||
buffer.Write([]byte(fmt.Sprintf(cell, col.Hex())))
|
buffer.Write([]byte(fmt.Sprintf(cell, col.Hex())))
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/muesli/clusters"
|
||||||
"github.com/muesli/kmeans"
|
"github.com/muesli/kmeans"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,9 +13,9 @@ func main() {
|
|||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
|
||||||
// set up a random two-dimensional data set (float64 values between 0.0 and 1.0)
|
// set up a random two-dimensional data set (float64 values between 0.0 and 1.0)
|
||||||
var d kmeans.Points
|
var d clusters.Observations
|
||||||
for x := 0; x < 1024; x++ {
|
for x := 0; x < 1024; x++ {
|
||||||
d = append(d, kmeans.Point{
|
d = append(d, clusters.Coordinates{
|
||||||
rand.Float64(),
|
rand.Float64(),
|
||||||
rand.Float64(),
|
rand.Float64(),
|
||||||
})
|
})
|
||||||
@@ -27,6 +28,6 @@ func main() {
|
|||||||
|
|
||||||
for i, c := range clusters {
|
for i, c := range clusters {
|
||||||
fmt.Printf("Cluster: %d\n", i)
|
fmt.Printf("Cluster: %d\n", i)
|
||||||
fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0]*255.0, c.Center[1]*255.0)
|
fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
54
kmeans.go
54
kmeans.go
@@ -5,7 +5,8 @@ package kmeans
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
|
||||||
|
"github.com/muesli/clusters"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Kmeans configuration/option struct
|
// Kmeans configuration/option struct
|
||||||
@@ -39,39 +40,16 @@ func New() Kmeans {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func randomizeClusters(k int, dataset Points) (Clusters, error) {
|
|
||||||
var c Clusters
|
|
||||||
if len(dataset) == 0 || len(dataset[0]) == 0 {
|
|
||||||
return c, fmt.Errorf("there must be at least one dimension in the data set")
|
|
||||||
}
|
|
||||||
if k == 0 {
|
|
||||||
return c, fmt.Errorf("k must be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
for i := 0; i < k; i++ {
|
|
||||||
var p Point
|
|
||||||
for j := 0; j < len(dataset[0]); j++ {
|
|
||||||
p = append(p, rand.Float64())
|
|
||||||
}
|
|
||||||
|
|
||||||
c = append(c, Cluster{
|
|
||||||
Center: p,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Partition executes the k-means algorithm on the given dataset and
|
// Partition executes the k-means algorithm on the given dataset and
|
||||||
// partitions it into k clusters
|
// partitions it into k clusters
|
||||||
func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
|
func (m Kmeans) Partition(dataset clusters.Observations, k int) (clusters.Clusters, error) {
|
||||||
if k > len(dataset) {
|
if k > len(dataset) {
|
||||||
return Clusters{}, fmt.Errorf("the size of the data set must at least equal k")
|
return clusters.Clusters{}, fmt.Errorf("the size of the data set must at least equal k")
|
||||||
}
|
}
|
||||||
|
|
||||||
clusters, err := randomizeClusters(k, dataset)
|
cc, err := clusters.New(k, dataset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Clusters{}, err
|
return cc, err
|
||||||
}
|
}
|
||||||
|
|
||||||
points := make([]int, len(dataset))
|
points := make([]int, len(dataset))
|
||||||
@@ -79,19 +57,19 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
|
|||||||
|
|
||||||
for i := 0; changes > 0; i++ {
|
for i := 0; changes > 0; i++ {
|
||||||
changes = 0
|
changes = 0
|
||||||
clusters.reset()
|
cc.Reset()
|
||||||
|
|
||||||
for p, point := range dataset {
|
for p, point := range dataset {
|
||||||
ci := clusters.Nearest(point)
|
ci := cc.Nearest(point)
|
||||||
clusters[ci].Points = append(clusters[ci].Points, point)
|
cc[ci].Append(point)
|
||||||
if points[p] != ci {
|
if points[p] != ci {
|
||||||
points[p] = ci
|
points[p] = ci
|
||||||
changes++
|
changes++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for ci := 0; ci < len(clusters); ci++ {
|
for ci := 0; ci < len(cc); ci++ {
|
||||||
if len(clusters[ci].Points) == 0 {
|
if len(cc[ci].Observations) == 0 {
|
||||||
// During the iterations, if any of the cluster centers has no
|
// During the iterations, if any of the cluster centers has no
|
||||||
// data points associated with it, assign a random data point
|
// data points associated with it, assign a random data point
|
||||||
// to it.
|
// to it.
|
||||||
@@ -101,20 +79,20 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
|
|||||||
// find a cluster with at least two data points, otherwise
|
// find a cluster with at least two data points, otherwise
|
||||||
// we're just emptying one cluster to fill another
|
// we're just emptying one cluster to fill another
|
||||||
ri = rand.Intn(len(dataset))
|
ri = rand.Intn(len(dataset))
|
||||||
if len(clusters[points[ri]].Points) > 1 {
|
if len(cc[points[ri]].Observations) > 1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
clusters[ci].Points = append(clusters[ci].Points, dataset[ri])
|
cc[ci].Append(dataset[ri])
|
||||||
points[ri] = ci
|
points[ri] = ci
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if changes > 0 {
|
if changes > 0 {
|
||||||
clusters.recenter()
|
cc.Recenter()
|
||||||
}
|
}
|
||||||
if m.plotter != nil {
|
if m.plotter != nil {
|
||||||
m.plotter.Plot(clusters, i)
|
m.plotter.Plot(cc, i)
|
||||||
}
|
}
|
||||||
if i == m.iterationThreshold ||
|
if i == m.iterationThreshold ||
|
||||||
changes < int(float64(len(dataset))*m.deltaThreshold) {
|
changes < int(float64(len(dataset))*m.deltaThreshold) {
|
||||||
@@ -123,5 +101,5 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return clusters, nil
|
return cc, nil
|
||||||
}
|
}
|
||||||
|
@@ -3,6 +3,8 @@ package kmeans
|
|||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/muesli/clusters"
|
||||||
)
|
)
|
||||||
|
|
||||||
var RANDOM_SEED = int64(42)
|
var RANDOM_SEED = int64(42)
|
||||||
@@ -21,13 +23,14 @@ func TestNewErrors(t *testing.T) {
|
|||||||
|
|
||||||
func TestPartitioningError(t *testing.T) {
|
func TestPartitioningError(t *testing.T) {
|
||||||
km := New()
|
km := New()
|
||||||
if _, err := km.Partition(Points{}, 1); err == nil {
|
d := clusters.Observations{}
|
||||||
|
if _, err := km.Partition(d, 1); err == nil {
|
||||||
t.Errorf("Expected error partitioning with empty data set, got nil")
|
t.Errorf("Expected error partitioning with empty data set, got nil")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := Points{
|
d = clusters.Observations{
|
||||||
Point{
|
clusters.Coordinates{
|
||||||
0.1,
|
0.1,
|
||||||
0.1,
|
0.1,
|
||||||
},
|
},
|
||||||
@@ -44,10 +47,10 @@ func TestPartitioningError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDimensions(t *testing.T) {
|
func TestDimensions(t *testing.T) {
|
||||||
var d Points
|
var d clusters.Observations
|
||||||
for x := 0; x < 255; x += 32 {
|
for x := 0; x < 255; x += 32 {
|
||||||
for y := 0; y < 255; y += 32 {
|
for y := 0; y < 255; y += 32 {
|
||||||
d = append(d, Point{
|
d = append(d, clusters.Coordinates{
|
||||||
float64(x) / 255.0,
|
float64(x) / 255.0,
|
||||||
float64(y) / 255.0,
|
float64(y) / 255.0,
|
||||||
})
|
})
|
||||||
@@ -69,10 +72,10 @@ func TestDimensions(t *testing.T) {
|
|||||||
|
|
||||||
func benchmarkPartition(size, partitions int, b *testing.B) {
|
func benchmarkPartition(size, partitions int, b *testing.B) {
|
||||||
rand.Seed(RANDOM_SEED)
|
rand.Seed(RANDOM_SEED)
|
||||||
var d Points
|
var d clusters.Observations
|
||||||
|
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
d = append(d, Point{
|
d = append(d, clusters.Coordinates{
|
||||||
rand.Float64(),
|
rand.Float64(),
|
||||||
rand.Float64(),
|
rand.Float64(),
|
||||||
})
|
})
|
||||||
|
16
plotter.go
16
plotter.go
@@ -5,13 +5,15 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/muesli/clusters"
|
||||||
|
|
||||||
"github.com/wcharczuk/go-chart"
|
"github.com/wcharczuk/go-chart"
|
||||||
"github.com/wcharczuk/go-chart/drawing"
|
"github.com/wcharczuk/go-chart/drawing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The Plotter interface lets you implement your own plotters
|
// The Plotter interface lets you implement your own plotters
|
||||||
type Plotter interface {
|
type Plotter interface {
|
||||||
Plot(clusters Clusters, iteration int)
|
Plot(cc clusters.Clusters, iteration int)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimplePlotter is the default standard plotter for 2-dimensional data sets
|
// SimplePlotter is the default standard plotter for 2-dimensional data sets
|
||||||
@@ -33,19 +35,19 @@ var colors = []drawing.Color{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Plot draw a 2-dimensional data set into a PNG file named {iteration}.png
|
// Plot draw a 2-dimensional data set into a PNG file named {iteration}.png
|
||||||
func (p SimplePlotter) Plot(clusters Clusters, iteration int) {
|
func (p SimplePlotter) Plot(cc clusters.Clusters, iteration int) {
|
||||||
var series []chart.Series
|
var series []chart.Series
|
||||||
|
|
||||||
// draw data points
|
// draw data points
|
||||||
for i, c := range clusters {
|
for i, c := range cc {
|
||||||
series = append(series, chart.ContinuousSeries{
|
series = append(series, chart.ContinuousSeries{
|
||||||
Style: chart.Style{
|
Style: chart.Style{
|
||||||
Show: true,
|
Show: true,
|
||||||
StrokeWidth: chart.Disabled,
|
StrokeWidth: chart.Disabled,
|
||||||
DotColor: colors[i%len(colors)],
|
DotColor: colors[i%len(colors)],
|
||||||
DotWidth: 8},
|
DotWidth: 8},
|
||||||
XValues: c.pointsInDimension(0),
|
XValues: c.PointsInDimension(0),
|
||||||
YValues: c.pointsInDimension(1),
|
YValues: c.PointsInDimension(1),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,8 +59,8 @@ func (p SimplePlotter) Plot(clusters Clusters, iteration int) {
|
|||||||
DotColor: drawing.ColorBlack,
|
DotColor: drawing.ColorBlack,
|
||||||
DotWidth: 16,
|
DotWidth: 16,
|
||||||
},
|
},
|
||||||
XValues: clusters.centersInDimension(0),
|
XValues: cc.CentersInDimension(0),
|
||||||
YValues: clusters.centersInDimension(1),
|
YValues: cc.CentersInDimension(1),
|
||||||
})
|
})
|
||||||
|
|
||||||
graph := chart.Chart{
|
graph := chart.Chart{
|
||||||
|
42
point.go
42
point.go
@@ -1,42 +0,0 @@
|
|||||||
package kmeans
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Point is a data point (float64 between 0.0 and 1.0) in n dimensions
|
|
||||||
type Point []float64
|
|
||||||
|
|
||||||
// Points is a slice of points
|
|
||||||
type Points []Point
|
|
||||||
|
|
||||||
// Distance returns the euclidean distance between two data points
|
|
||||||
func (p Point) Distance(p2 Point) float64 {
|
|
||||||
var r float64
|
|
||||||
for i, v := range p {
|
|
||||||
r += math.Pow(v-p2[i], 2)
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mean returns the mean point of p
|
|
||||||
func (p Points) Mean() (Point, error) {
|
|
||||||
var l = len(p)
|
|
||||||
if l == 0 {
|
|
||||||
return Point{}, fmt.Errorf("there is no mean for an empty set of points")
|
|
||||||
}
|
|
||||||
|
|
||||||
c := make([]float64, len(p[0]))
|
|
||||||
for _, point := range p {
|
|
||||||
for j, v := range point {
|
|
||||||
c[j] += v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var point Point
|
|
||||||
for _, v := range c {
|
|
||||||
point = append(point, v/float64(l))
|
|
||||||
}
|
|
||||||
return point, nil
|
|
||||||
}
|
|
Reference in New Issue
Block a user