From a2e4fa300cbe486e774660b480bcbf22764c88a7 Mon Sep 17 00:00:00 2001 From: Christian Muehlhaeuser Date: Sun, 27 May 2018 22:06:57 +0200 Subject: [PATCH] The Plotter interface lets you attach custom plotters --- README.md | 6 +++-- cluster.go | 44 ++++++++++++++++++++++++----------- examples/colorpalette/main.go | 2 +- kmeans.go | 12 +++++----- draw.go => plotter.go | 27 ++++++++------------- 5 files changed, 51 insertions(+), 40 deletions(-) rename draw.go => plotter.go (84%) diff --git a/README.md b/README.md index 1ad0c34..275f6cc 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ threshold. With the following options the algorithm finishes when less than 5% of the data points shifted their cluster assignment in the last iteration: ```go -km, err := kmeans.NewWithOptions(0.05, false) +km, err := kmeans.NewWithOptions(0.05, nil) ``` The default setting for the delta threshold is 0.01 (1%). @@ -68,11 +68,13 @@ If you are working with two-dimensional data sets, kmeans can generate beautiful graphs (like the one above) for each iteration of the algorithm: ```go -km, err := kmeans.NewWithOptions(0.01, true) +km, err := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{}) ``` Careful: this will generate PNGs in your current working directory. +You can write your own plotters by implementing the `kmeans.Plotter` interface. + ## Development [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://godoc.org/github.com/muesli/kmeans) diff --git a/cluster.go b/cluster.go index 3eeb769..d2b02f2 100644 --- a/cluster.go +++ b/cluster.go @@ -9,6 +9,23 @@ type Cluster struct { // Clusters is a slice of clusters type Clusters []Cluster +// Nearest returns the index of the cluster nearest to point +func (c Clusters) Nearest(point Point) int { + var dist float64 + var ci int + + // Find the nearest cluster for this data point + for i, cluster := range c { + d := point.Distance(cluster.Center) + if dist == 0 || d < dist { + dist = d + ci = i + } + } + + return ci +} + // recenter recenters a cluster func (c *Cluster) recenter() { center, err := c.Points.Mean() @@ -33,19 +50,18 @@ func (c Clusters) reset() { } } -// Nearest returns the index of the cluster nearest to point -func (c Clusters) Nearest(point Point) int { - var dist float64 - var ci int - - // Find the nearest cluster for this data point - for i, cluster := range c { - d := point.Distance(cluster.Center) - if dist == 0 || d < dist { - dist = d - ci = i - } +func (cluster *Cluster) pointsInDimension(n int) []float64 { + var v []float64 + for _, p := range cluster.Points { + v = append(v, p[n]) } - - return ci + return v +} + +func (clusters Clusters) centersInDimension(n int) []float64 { + var v []float64 + for _, c := range clusters { + v = append(v, c.Center[n]) + } + return v } diff --git a/examples/colorpalette/main.go b/examples/colorpalette/main.go index 70e1822..3a40d9d 100644 --- a/examples/colorpalette/main.go +++ b/examples/colorpalette/main.go @@ -52,7 +52,7 @@ func main() { buffer.Write([]byte(header)) // Enable graph generation (.png files) for each iteration - km, _ := kmeans.NewWithOptions(0.01, true) + km, _ := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{}) // Partition the color space into 16 clusters (palette colors) clusters, _ := km.Partition(d, 16) diff --git a/kmeans.go b/kmeans.go index 87b3687..4f99225 100644 --- a/kmeans.go +++ b/kmeans.go @@ -11,27 +11,27 @@ import ( // Kmeans configuration/option struct type Kmeans struct { // when Debug is enabled, graphs are generated after each iteration - debug bool + plotter Plotter // DeltaThreshold (in percent between 0.0 and 0.1) aborts processing if // less than n% of data points shifted clusters in the last iteration deltaThreshold float64 } // NewWithOptions returns a Kmeans configuration struct with custom settings -func NewWithOptions(deltaThreshold float64, debug bool) (Kmeans, error) { +func NewWithOptions(deltaThreshold float64, plotter Plotter) (Kmeans, error) { if deltaThreshold <= 0.0 || deltaThreshold >= 1.0 { return Kmeans{}, fmt.Errorf("threshold is out of bounds (must be >0.0 and <1.0, in percent)") } return Kmeans{ - debug: debug, + plotter: plotter, deltaThreshold: deltaThreshold, }, nil } // New returns a Kmeans configuration struct with default settings func New() Kmeans { - m, _ := NewWithOptions(0.01, false) + m, _ := NewWithOptions(0.01, nil) return m } @@ -100,8 +100,8 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) { if changes > 0 { clusters.recenter() } - if m.debug { - draw(clusters, i) + if m.plotter != nil { + m.plotter.Plot(clusters, i) } if changes < int(float64(len(dataset))*m.deltaThreshold) { // fmt.Println("Aborting:", changes, int(float64(len(dataset))*m.TerminationThreshold)) diff --git a/draw.go b/plotter.go similarity index 84% rename from draw.go rename to plotter.go index e9d50d9..c3794b1 100644 --- a/draw.go +++ b/plotter.go @@ -9,6 +9,14 @@ import ( "github.com/wcharczuk/go-chart/drawing" ) +type Plotter interface { + Plot(clusters Clusters, iteration int) +} + +// SimplePlotter is the default standard plotter for 2-dimensional data sets +type SimplePlotter struct { +} + var colors = []drawing.Color{ drawing.ColorFromHex("f92672"), drawing.ColorFromHex("89bdff"), @@ -22,23 +30,8 @@ var colors = []drawing.Color{ drawing.ColorFromHex("dcc060"), } -func (cluster *Cluster) pointsInDimension(n int) []float64 { - var v []float64 - for _, p := range cluster.Points { - v = append(v, p[n]) - } - return v -} - -func (clusters Clusters) centersInDimension(n int) []float64 { - var v []float64 - for _, c := range clusters { - v = append(v, c.Center[n]) - } - return v -} - -func draw(clusters Clusters, iteration int) { +// Plot draw a 2-dimensional data set into a PNG file named {iteration}.png +func (p SimplePlotter) Plot(clusters Clusters, iteration int) { var series []chart.Series // draw data points