The Plotter interface lets you attach custom plotters

This commit is contained in:
Christian Muehlhaeuser
2018-05-27 22:06:57 +02:00
parent 43c4b4685f
commit a2e4fa300c
5 changed files with 51 additions and 40 deletions

View File

@@ -59,7 +59,7 @@ threshold. With the following options the algorithm finishes when less than 5%
of the data points shifted their cluster assignment in the last iteration:
```go
km, err := kmeans.NewWithOptions(0.05, false)
km, err := kmeans.NewWithOptions(0.05, nil)
```
The default setting for the delta threshold is 0.01 (1%).
@@ -68,11 +68,13 @@ If you are working with two-dimensional data sets, kmeans can generate
beautiful graphs (like the one above) for each iteration of the algorithm:
```go
km, err := kmeans.NewWithOptions(0.01, true)
km, err := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{})
```
Careful: this will generate PNGs in your current working directory.
You can write your own plotters by implementing the `kmeans.Plotter` interface.
## Development
[![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://godoc.org/github.com/muesli/kmeans)

View File

@@ -9,6 +9,23 @@ type Cluster struct {
// Clusters is a slice of clusters
type Clusters []Cluster
// Nearest returns the index of the cluster nearest to point
func (c Clusters) Nearest(point Point) int {
var dist float64
var ci int
// Find the nearest cluster for this data point
for i, cluster := range c {
d := point.Distance(cluster.Center)
if dist == 0 || d < dist {
dist = d
ci = i
}
}
return ci
}
// recenter recenters a cluster
func (c *Cluster) recenter() {
center, err := c.Points.Mean()
@@ -33,19 +50,18 @@ func (c Clusters) reset() {
}
}
// Nearest returns the index of the cluster nearest to point
func (c Clusters) Nearest(point Point) int {
var dist float64
var ci int
// Find the nearest cluster for this data point
for i, cluster := range c {
d := point.Distance(cluster.Center)
if dist == 0 || d < dist {
dist = d
ci = i
func (cluster *Cluster) pointsInDimension(n int) []float64 {
var v []float64
for _, p := range cluster.Points {
v = append(v, p[n])
}
return v
}
return ci
func (clusters Clusters) centersInDimension(n int) []float64 {
var v []float64
for _, c := range clusters {
v = append(v, c.Center[n])
}
return v
}

View File

@@ -52,7 +52,7 @@ func main() {
buffer.Write([]byte(header))
// Enable graph generation (.png files) for each iteration
km, _ := kmeans.NewWithOptions(0.01, true)
km, _ := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{})
// Partition the color space into 16 clusters (palette colors)
clusters, _ := km.Partition(d, 16)

View File

@@ -11,27 +11,27 @@ import (
// Kmeans configuration/option struct
type Kmeans struct {
// when Debug is enabled, graphs are generated after each iteration
debug bool
plotter Plotter
// DeltaThreshold (in percent between 0.0 and 0.1) aborts processing if
// less than n% of data points shifted clusters in the last iteration
deltaThreshold float64
}
// NewWithOptions returns a Kmeans configuration struct with custom settings
func NewWithOptions(deltaThreshold float64, debug bool) (Kmeans, error) {
func NewWithOptions(deltaThreshold float64, plotter Plotter) (Kmeans, error) {
if deltaThreshold <= 0.0 || deltaThreshold >= 1.0 {
return Kmeans{}, fmt.Errorf("threshold is out of bounds (must be >0.0 and <1.0, in percent)")
}
return Kmeans{
debug: debug,
plotter: plotter,
deltaThreshold: deltaThreshold,
}, nil
}
// New returns a Kmeans configuration struct with default settings
func New() Kmeans {
m, _ := NewWithOptions(0.01, false)
m, _ := NewWithOptions(0.01, nil)
return m
}
@@ -100,8 +100,8 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
if changes > 0 {
clusters.recenter()
}
if m.debug {
draw(clusters, i)
if m.plotter != nil {
m.plotter.Plot(clusters, i)
}
if changes < int(float64(len(dataset))*m.deltaThreshold) {
// fmt.Println("Aborting:", changes, int(float64(len(dataset))*m.TerminationThreshold))

View File

@@ -9,6 +9,14 @@ import (
"github.com/wcharczuk/go-chart/drawing"
)
type Plotter interface {
Plot(clusters Clusters, iteration int)
}
// SimplePlotter is the default standard plotter for 2-dimensional data sets
type SimplePlotter struct {
}
var colors = []drawing.Color{
drawing.ColorFromHex("f92672"),
drawing.ColorFromHex("89bdff"),
@@ -22,23 +30,8 @@ var colors = []drawing.Color{
drawing.ColorFromHex("dcc060"),
}
func (cluster *Cluster) pointsInDimension(n int) []float64 {
var v []float64
for _, p := range cluster.Points {
v = append(v, p[n])
}
return v
}
func (clusters Clusters) centersInDimension(n int) []float64 {
var v []float64
for _, c := range clusters {
v = append(v, c.Center[n])
}
return v
}
func draw(clusters Clusters, iteration int) {
// Plot draw a 2-dimensional data set into a PNG file named {iteration}.png
func (p SimplePlotter) Plot(clusters Clusters, iteration int) {
var series []chart.Series
// draw data points