mirror of
https://github.com/muesli/kmeans.git
synced 2025-09-27 03:56:17 +08:00
The Plotter interface lets you attach custom plotters
This commit is contained in:
@@ -59,7 +59,7 @@ threshold. With the following options the algorithm finishes when less than 5%
|
|||||||
of the data points shifted their cluster assignment in the last iteration:
|
of the data points shifted their cluster assignment in the last iteration:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
km, err := kmeans.NewWithOptions(0.05, false)
|
km, err := kmeans.NewWithOptions(0.05, nil)
|
||||||
```
|
```
|
||||||
|
|
||||||
The default setting for the delta threshold is 0.01 (1%).
|
The default setting for the delta threshold is 0.01 (1%).
|
||||||
@@ -68,11 +68,13 @@ If you are working with two-dimensional data sets, kmeans can generate
|
|||||||
beautiful graphs (like the one above) for each iteration of the algorithm:
|
beautiful graphs (like the one above) for each iteration of the algorithm:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
km, err := kmeans.NewWithOptions(0.01, true)
|
km, err := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{})
|
||||||
```
|
```
|
||||||
|
|
||||||
Careful: this will generate PNGs in your current working directory.
|
Careful: this will generate PNGs in your current working directory.
|
||||||
|
|
||||||
|
You can write your own plotters by implementing the `kmeans.Plotter` interface.
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
[GoDoc](https://godoc.org/github.com/muesli/kmeans)
|
[GoDoc](https://godoc.org/github.com/muesli/kmeans)
|
||||||
|
44
cluster.go
44
cluster.go
@@ -9,6 +9,23 @@ type Cluster struct {
|
|||||||
// Clusters is a slice of clusters
|
// Clusters is a slice of clusters
|
||||||
type Clusters []Cluster
|
type Clusters []Cluster
|
||||||
|
|
||||||
|
// Nearest returns the index of the cluster nearest to point
|
||||||
|
func (c Clusters) Nearest(point Point) int {
|
||||||
|
var dist float64
|
||||||
|
var ci int
|
||||||
|
|
||||||
|
// Find the nearest cluster for this data point
|
||||||
|
for i, cluster := range c {
|
||||||
|
d := point.Distance(cluster.Center)
|
||||||
|
if dist == 0 || d < dist {
|
||||||
|
dist = d
|
||||||
|
ci = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ci
|
||||||
|
}
|
||||||
|
|
||||||
// recenter recenters a cluster
|
// recenter recenters a cluster
|
||||||
func (c *Cluster) recenter() {
|
func (c *Cluster) recenter() {
|
||||||
center, err := c.Points.Mean()
|
center, err := c.Points.Mean()
|
||||||
@@ -33,19 +50,18 @@ func (c Clusters) reset() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nearest returns the index of the cluster nearest to point
|
func (cluster *Cluster) pointsInDimension(n int) []float64 {
|
||||||
func (c Clusters) Nearest(point Point) int {
|
var v []float64
|
||||||
var dist float64
|
for _, p := range cluster.Points {
|
||||||
var ci int
|
v = append(v, p[n])
|
||||||
|
|
||||||
// Find the nearest cluster for this data point
|
|
||||||
for i, cluster := range c {
|
|
||||||
d := point.Distance(cluster.Center)
|
|
||||||
if dist == 0 || d < dist {
|
|
||||||
dist = d
|
|
||||||
ci = i
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return v
|
||||||
return ci
|
}
|
||||||
|
|
||||||
|
func (clusters Clusters) centersInDimension(n int) []float64 {
|
||||||
|
var v []float64
|
||||||
|
for _, c := range clusters {
|
||||||
|
v = append(v, c.Center[n])
|
||||||
|
}
|
||||||
|
return v
|
||||||
}
|
}
|
||||||
|
@@ -52,7 +52,7 @@ func main() {
|
|||||||
buffer.Write([]byte(header))
|
buffer.Write([]byte(header))
|
||||||
|
|
||||||
// Enable graph generation (.png files) for each iteration
|
// Enable graph generation (.png files) for each iteration
|
||||||
km, _ := kmeans.NewWithOptions(0.01, true)
|
km, _ := kmeans.NewWithOptions(0.01, kmeans.SimplePlotter{})
|
||||||
|
|
||||||
// Partition the color space into 16 clusters (palette colors)
|
// Partition the color space into 16 clusters (palette colors)
|
||||||
clusters, _ := km.Partition(d, 16)
|
clusters, _ := km.Partition(d, 16)
|
||||||
|
12
kmeans.go
12
kmeans.go
@@ -11,27 +11,27 @@ import (
|
|||||||
// Kmeans configuration/option struct
|
// Kmeans configuration/option struct
|
||||||
type Kmeans struct {
|
type Kmeans struct {
|
||||||
// when Debug is enabled, graphs are generated after each iteration
|
// when a plotter is set, its Plot method is called after each iteration
|
||||||
debug bool
|
plotter Plotter
|
||||||
// DeltaThreshold (a fraction between 0.0 and 1.0, exclusive) aborts processing if
|
// DeltaThreshold (a fraction between 0.0 and 1.0, exclusive) aborts processing if
|
||||||
// less than n% of data points shifted clusters in the last iteration
|
// less than n% of data points shifted clusters in the last iteration
|
||||||
deltaThreshold float64
|
deltaThreshold float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithOptions returns a Kmeans configuration struct with custom settings
|
// NewWithOptions returns a Kmeans configuration struct with custom settings
|
||||||
func NewWithOptions(deltaThreshold float64, debug bool) (Kmeans, error) {
|
func NewWithOptions(deltaThreshold float64, plotter Plotter) (Kmeans, error) {
|
||||||
if deltaThreshold <= 0.0 || deltaThreshold >= 1.0 {
|
if deltaThreshold <= 0.0 || deltaThreshold >= 1.0 {
|
||||||
return Kmeans{}, fmt.Errorf("threshold is out of bounds (must be >0.0 and <1.0, in percent)")
|
return Kmeans{}, fmt.Errorf("threshold is out of bounds (must be >0.0 and <1.0, in percent)")
|
||||||
}
|
}
|
||||||
|
|
||||||
return Kmeans{
|
return Kmeans{
|
||||||
debug: debug,
|
plotter: plotter,
|
||||||
deltaThreshold: deltaThreshold,
|
deltaThreshold: deltaThreshold,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a Kmeans configuration struct with default settings
|
// New returns a Kmeans configuration struct with default settings
|
||||||
func New() Kmeans {
|
func New() Kmeans {
|
||||||
m, _ := NewWithOptions(0.01, false)
|
m, _ := NewWithOptions(0.01, nil)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,8 +100,8 @@ func (m Kmeans) Partition(dataset Points, k int) (Clusters, error) {
|
|||||||
if changes > 0 {
|
if changes > 0 {
|
||||||
clusters.recenter()
|
clusters.recenter()
|
||||||
}
|
}
|
||||||
if m.debug {
|
if m.plotter != nil {
|
||||||
draw(clusters, i)
|
m.plotter.Plot(clusters, i)
|
||||||
}
|
}
|
||||||
if changes < int(float64(len(dataset))*m.deltaThreshold) {
|
if changes < int(float64(len(dataset))*m.deltaThreshold) {
|
||||||
// fmt.Println("Aborting:", changes, int(float64(len(dataset))*m.TerminationThreshold))
|
// fmt.Println("Aborting:", changes, int(float64(len(dataset))*m.TerminationThreshold))
|
||||||
|
@@ -9,6 +9,14 @@ import (
|
|||||||
"github.com/wcharczuk/go-chart/drawing"
|
"github.com/wcharczuk/go-chart/drawing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Plotter interface {
|
||||||
|
Plot(clusters Clusters, iteration int)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimplePlotter is the default standard plotter for 2-dimensional data sets
|
||||||
|
type SimplePlotter struct {
|
||||||
|
}
|
||||||
|
|
||||||
var colors = []drawing.Color{
|
var colors = []drawing.Color{
|
||||||
drawing.ColorFromHex("f92672"),
|
drawing.ColorFromHex("f92672"),
|
||||||
drawing.ColorFromHex("89bdff"),
|
drawing.ColorFromHex("89bdff"),
|
||||||
@@ -22,23 +30,8 @@ var colors = []drawing.Color{
|
|||||||
drawing.ColorFromHex("dcc060"),
|
drawing.ColorFromHex("dcc060"),
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cluster *Cluster) pointsInDimension(n int) []float64 {
|
// Plot draws a 2-dimensional data set into a PNG file named {iteration}.png
|
||||||
var v []float64
|
func (p SimplePlotter) Plot(clusters Clusters, iteration int) {
|
||||||
for _, p := range cluster.Points {
|
|
||||||
v = append(v, p[n])
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (clusters Clusters) centersInDimension(n int) []float64 {
|
|
||||||
var v []float64
|
|
||||||
for _, c := range clusters {
|
|
||||||
v = append(v, c.Center[n])
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func draw(clusters Clusters, iteration int) {
|
|
||||||
var series []chart.Series
|
var series []chart.Series
|
||||||
|
|
||||||
// draw data points
|
// draw data points
|
Reference in New Issue
Block a user