Files
gonum/stat/distmv/dirichlet.go
2018-05-03 07:40:18 +09:30

146 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmv
import (
"math"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat/distuv"
)
// Dirichlet implements the Dirichlet probability distribution.
//
// The Dirichlet distribution is a continuous probability distribution whose
// samples lie on the probability simplex, i.e. x_i >= 0 and ||x||_1 = 1. It is
// the conjugate prior of the categorical distribution and the multivariate
// generalization of the beta distribution. The density at a point x is
//  1/Beta(α) \prod_i x_i^(α_i - 1)
// where Beta(α) is the multivariate Beta function (see the mathext package).
//
// For more information see https://en.wikipedia.org/wiki/Dirichlet_distribution
type Dirichlet struct {
	alpha []float64   // concentration parameters (a private copy)
	dim   int         // number of components, len(alpha)
	src   rand.Source // source of randomness used by Rand

	lbeta    float64 // cached log of the multivariate Beta function of alpha
	sumAlpha float64 // cached sum of the elements of alpha
}
// NewDirichlet creates a new dirichlet distribution with the given parameters alpha.
// The alpha slice is copied, so later modifications by the caller do not affect
// the returned distribution.
//
// NewDirichlet will panic if len(alpha) == 0, or if any alpha is <= 0 or NaN.
func NewDirichlet(alpha []float64, src rand.Source) *Dirichlet {
	if len(alpha) == 0 {
		panic(badZeroDimension)
	}
	for _, v := range alpha {
		// Written as !(v > 0) rather than v <= 0 so that NaN, for which every
		// comparison is false, is rejected as well.
		if !(v > 0) {
			panic("dirichlet: non-positive alpha")
		}
	}
	a := make([]float64, len(alpha))
	copy(a, alpha)
	d := &Dirichlet{
		alpha: a,
		dim:   len(a),
		src:   src,
	}
	d.lbeta, d.sumAlpha = d.genLBeta(a)
	return d
}
// CovarianceMatrix returns the covariance matrix of the distribution. Upon
// return, the value at element {i, j} of the covariance matrix is equal to
// the covariance of the i^th and j^th variables.
//  covariance(i, j) = E[(x_i - E[x_i])(x_j - E[x_j])]
// With α_0 = \sum_k α_k, for the Dirichlet distribution this is
//  covariance(i, j) = (δ_ij α_i α_0 - α_i α_j) / (α_0^2 (α_0 + 1)).
// If the input matrix is nil a new matrix is allocated, otherwise the result
// is stored in-place into the input. An input of size zero is grown to
// Dim×Dim; an input of any other size that differs from Dim causes a panic.
func (d *Dirichlet) CovarianceMatrix(cov *mat.SymDense) *mat.SymDense {
	switch {
	case cov == nil:
		cov = mat.NewSymDense(d.dim, nil)
	case cov.Symmetric() == 0:
		// Empty input: grow it to the required size and store it back in place.
		*cov = *(cov.GrowSquare(d.dim).(*mat.SymDense))
	case cov.Symmetric() != d.dim:
		panic("dirichlet: input matrix size mismatch")
	}
	// Common denominator 1 / (α_0^2 (α_0 + 1)) shared by every element.
	scale := 1 / (d.sumAlpha * d.sumAlpha * (d.sumAlpha + 1))
	for i, ai := range d.alpha {
		cov.SetSym(i, i, ai*(d.sumAlpha-ai)*scale)
		for j := i + 1; j < d.dim; j++ {
			cov.SetSym(i, j, -ai*d.alpha[j]*scale)
		}
	}
	return cov
}
// genLBeta returns the logarithm of the multivariate Beta function of alpha,
//  log B(α) = \sum_i log Γ(α_i) - log Γ(\sum_i α_i),
// together with the sum of the elements of alpha.
//
// The sign results of math.Lgamma are discarded. This is valid only because
// NewDirichlet rejects non-positive alpha values, so every Γ value involved
// is positive.
func (d *Dirichlet) genLBeta(alpha []float64) (lbeta, sumAlpha float64) {
	for _, a := range alpha {
		lg, _ := math.Lgamma(a)
		lbeta += lg
		sumAlpha += a
	}
	lg, _ := math.Lgamma(sumAlpha)
	return lbeta - lg, sumAlpha
}
// Dim returns the dimension of the distribution, i.e. the number of
// concentration parameters it was constructed with.
func (d *Dirichlet) Dim() int { return d.dim }
// LogProb computes the log of the pdf of the point x,
//  \sum_i (α_i - 1) log(x_i) - log Beta(α).
// LogProb panics if len(x) != Dim.
//
// It does not check that ||x||_1 = 1.
func (d *Dirichlet) LogProb(x []float64) float64 {
	if len(x) != d.dim {
		panic(badSizeMismatch)
	}
	var lprob float64
	for i, xi := range x {
		a := d.alpha[i]
		if a == 1 && xi == 0 {
			// x_i^0 == 1, so the term is log(1) == 0. Computing it directly
			// would give 0 * log(0) = 0 * -Inf = NaN on the simplex boundary.
			continue
		}
		lprob += (a - 1) * math.Log(xi)
	}
	return lprob - d.lbeta
}
// Mean returns the mean of the probability distribution, whose i^th element
// is α_i / \sum_k α_k. If x is nil, a new slice will be allocated, otherwise
// the result will be stored in place into x.
func (d *Dirichlet) Mean(x []float64) []float64 {
	x = reuseAs(x, d.dim)
	copy(x, d.alpha)
	// Multiply by the precomputed reciprocal of the sum of alpha.
	floats.Scale(1/d.sumAlpha, x)
	return x
}
// Prob returns the probability density function of the distribution evaluated
// at x, obtained by exponentiating LogProb(x). Like LogProb, it does not check
// that ||x||_1 = 1.
func (d *Dirichlet) Prob(x []float64) float64 {
	logDensity := d.LogProb(x)
	return math.Exp(logDensity)
}
// Rand generates a random sample from the distribution.
// If the input slice is nil, new memory is allocated, otherwise the result is stored
// in place.
//
// The sample is drawn by generating a Gamma(α_i, 1) variate for each element
// using the distribution's random source and then normalizing the vector by
// its sum.
func (d *Dirichlet) Rand(x []float64) []float64 {
	x = reuseAs(x, d.dim)
	for i := range x {
		x[i] = distuv.Gamma{Alpha: d.alpha[i], Beta: 1, Src: d.src}.Rand()
	}
	floats.Scale(1/floats.Sum(x), x)
	return x
}