The Plotter interface lets you attach custom plotters

This commit is contained in:
Christian Muehlhaeuser
2018-05-27 22:06:57 +02:00
parent 43c4b4685f
commit a2e4fa300c
5 changed files with 51 additions and 40 deletions

View File

@@ -59,7 +59,7 @@ threshold. With the following options the algorithm finishes when less than 5%
of the data points shifted their cluster assignment in the last iteration: of the data points shifted their cluster assignment in the last iteration:
```go ```go
km, err := kmeans.NewWithOptions(0.05, false) km, err := kmeans.NewWithOptions(0.05, nil)
``` ```
The default setting for the delta threshold is 0.01 (1%). The default setting for the delta threshold is 0.01 (1%).
@@ -68,11 +68,13 @@ If you are working with two-dimensional data sets, kmeans can generate
beautiful graphs (like the one above) for each iteration of the algorithm: beautiful graphs (like the one above) for each iteration of the algorithm:
```go ```go
km, err := kmeans.NewWithOptions(0.01, true) km, err := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{})
``` ```
Careful: this will generate PNGs in your current working directory. Careful: this will generate PNGs in your current working directory.
You can write your own plotters by implementing the `kmeans.Plotter` interface.
## Development ## Development
[![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://godoc.org/github.com/muesli/kmeans) [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://godoc.org/github.com/muesli/kmeans)

View File

@@ -9,6 +9,23 @@ type Cluster struct {
// Clusters is a slice of clusters // Clusters is a slice of clusters
type Clusters []Cluster type Clusters []Cluster
// Nearest returns the index of the cluster nearest to point
func (c Clusters) Nearest(point Point) int {
var dist float64
var ci int
// Find the nearest cluster for this data point
for i, cluster := range c {
d := point.Distance(cluster.Center)
if dist == 0 || d < dist {
dist = d
ci = i
}
}
return ci
}
// recenter recenters a cluster // recenter recenters a cluster
func (c *Cluster) recenter() { func (c *Cluster) recenter() {
center, err := c.Points.Mean() center, err := c.Points.Mean()
@@ -33,19 +50,18 @@ func (c Clusters) reset() {
} }
} }
// Nearest returns the index of the cluster nearest to point func (cluster *Cluster) pointsInDimension(n int) []float64 {
func (c Clusters) Nearest(point Point) int { var v []float64
var dist float64 for _, p := range cluster.Points {
var ci int v = append(v, p[n])
// Find the nearest cluster for this data point
for i, cluster := range c {
d := point.Distance(cluster.Center)
if dist == 0 || d < dist {
dist = d
ci = i
}
} }
return v
return ci }
func (clusters Clusters) centersInDimension(n int) []float64 {
var v []float64
for _, c := range clusters {
v = append(v, c.Center[n])
}
return v
} }

View File

@@ -52,7 +52,7 @@ func main() {
buffer.Write([]byte(header)) buffer.Write([]byte(header))
// Enable graph generation (.png files) for each iteration // Enable graph generation (.png files) for each iteration
km, _ := kmeans.NewWithOptions(0.01, true) km, _ := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{})
// Partition the color space into 16 clusters (palette colors) // Partition the color space into 16 clusters (palette colors)
clusters, _ := km.Partition(d, 16) clusters, _ := km.Partition(d, 16)

View File

@@ -11,27 +11,27 @@ import (
// Kmeans configuration/option struct // Kmeans configuration/option struct
type Kmeans struct { type Kmeans struct {
// when Debug is enabled, graphs are generated after each iteration // when Debug is enabled, graphs are generated after each iteration
debug bool plotter Plotter
// DeltaThreshold (in percent between 0.0 and 0.1) aborts processing if // DeltaThreshold (in percent between 0.0 and 0.1) aborts processing if
// less than n% of data points shifted clusters in the last iteration // less than n% of data points shifted clusters in the last iteration
deltaThreshold float64 deltaThreshold float64
} }
// NewWithOptions returns a Kmeans configuration struct with custom settings // NewWithOptions returns a Kmeans configuration struct with custom settings
func NewWithOptions(deltaThreshold float64, debug bool) (Kmeans, error) { func NewWithOptions(deltaThreshold float64, plotter Plotter) (Kmeans, error) {
if deltaThreshold <= 0.0 || deltaThreshold >= 1.0 { if deltaThreshold <= 0.0 || deltaThreshold >= 1.0 {
return Kmeans{}, fmt.Errorf("threshold is out of bounds (must be >0.0 and <1.0, in percent)") return Kmeans{}, fmt.Errorf("threshold is out of bounds (must be >0.0 and <1.0, in percent)")
} }
return Kmeans{ return Kmeans{
debug: debug, plotter: plotter,
deltaThreshold: deltaThreshold, deltaThreshold: deltaThreshold,
}, nil }, nil
} }
// New returns a Kmeans configuration struct with default settings // New returns a Kmeans configuration struct with default settings
func New() Kmeans { func New() Kmeans {
m, _ := NewWithOptions(0.01, false) m, _ := NewWithOptions(0.01, nil)
return m return m
} }
@@ -100,8 +100,8 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
if changes > 0 { if changes > 0 {
clusters.recenter() clusters.recenter()
} }
if m.debug { if m.plotter != nil {
draw(clusters, i) m.plotter.Plot(clusters, i)
} }
if changes < int(float64(len(dataset))*m.deltaThreshold) { if changes < int(float64(len(dataset))*m.deltaThreshold) {
// fmt.Println("Aborting:", changes, int(float64(len(dataset))*m.TerminationThreshold)) // fmt.Println("Aborting:", changes, int(float64(len(dataset))*m.TerminationThreshold))

View File

@@ -9,6 +9,14 @@ import (
"github.com/wcharczuk/go-chart/drawing" "github.com/wcharczuk/go-chart/drawing"
) )
type Plotter interface {
Plot(clusters Clusters, iteration int)
}
// SimplePlotter is the default standard plotter for 2-dimensional data sets
type SimplePlotter struct {
}
var colors = []drawing.Color{ var colors = []drawing.Color{
drawing.ColorFromHex("f92672"), drawing.ColorFromHex("f92672"),
drawing.ColorFromHex("89bdff"), drawing.ColorFromHex("89bdff"),
@@ -22,23 +30,8 @@ var colors = []drawing.Color{
drawing.ColorFromHex("dcc060"), drawing.ColorFromHex("dcc060"),
} }
func (cluster *Cluster) pointsInDimension(n int) []float64 { // Plot draw a 2-dimensional data set into a PNG file named {iteration}.png
var v []float64 func (p SimplePlotter) Plot(clusters Clusters, iteration int) {
for _, p := range cluster.Points {
v = append(v, p[n])
}
return v
}
func (clusters Clusters) centersInDimension(n int) []float64 {
var v []float64
for _, c := range clusters {
v = append(v, c.Center[n])
}
return v
}
func draw(clusters Clusters, iteration int) {
var series []chart.Series var series []chart.Series
// draw data points // draw data points