mirror of
https://github.com/gonum/gonum.git
synced 2025-10-08 16:40:06 +08:00
146 lines
4.0 KiB
Go
146 lines
4.0 KiB
Go
// Copyright ©2016 The Gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
package distmv
|
||
|
||
import (
|
||
"math"
|
||
|
||
"golang.org/x/exp/rand"
|
||
|
||
"gonum.org/v1/gonum/floats"
|
||
"gonum.org/v1/gonum/mat"
|
||
"gonum.org/v1/gonum/stat/distuv"
|
||
)
|
||
|
||
// Dirichlet implements the Dirichlet probability distribution.
|
||
//
|
||
// The Dirichlet distribution is a continuous probability distribution that
|
||
// generates elements over the probability simplex, i.e. ||x||_1 = 1. The Dirichlet
|
||
// distribution is the conjugate prior to the categorical distribution and the
|
||
// multivariate version of the beta distribution. The probability of a point x is
|
||
// 1/Beta(α) \prod_i x_i^(α_i - 1)
|
||
// where Beta(α) is the multivariate Beta function (see the mathext package).
|
||
//
|
||
// For more information see https://en.wikipedia.org/wiki/Dirichlet_distribution
|
||
type Dirichlet struct {
|
||
alpha []float64
|
||
dim int
|
||
src rand.Source
|
||
|
||
lbeta float64
|
||
sumAlpha float64
|
||
}
|
||
|
||
// NewDirichlet creates a new dirichlet distribution with the given parameters alpha.
|
||
// NewDirichlet will panic if len(alpha) == 0, or if any alpha is <= 0.
|
||
func NewDirichlet(alpha []float64, src rand.Source) *Dirichlet {
|
||
dim := len(alpha)
|
||
if dim == 0 {
|
||
panic(badZeroDimension)
|
||
}
|
||
for _, v := range alpha {
|
||
if v <= 0 {
|
||
panic("dirichlet: non-positive alpha")
|
||
}
|
||
}
|
||
a := make([]float64, len(alpha))
|
||
copy(a, alpha)
|
||
d := &Dirichlet{
|
||
alpha: a,
|
||
dim: dim,
|
||
src: src,
|
||
}
|
||
d.lbeta, d.sumAlpha = d.genLBeta(a)
|
||
return d
|
||
}
|
||
|
||
// CovarianceMatrix returns the covariance matrix of the distribution. Upon
|
||
// return, the value at element {i, j} of the covariance matrix is equal to
|
||
// the covariance of the i^th and j^th variables.
|
||
// covariance(i, j) = E[(x_i - E[x_i])(x_j - E[x_j])]
|
||
// If the input matrix is nil a new matrix is allocated, otherwise the result
|
||
// is stored in-place into the input.
|
||
func (d *Dirichlet) CovarianceMatrix(cov *mat.SymDense) *mat.SymDense {
|
||
if cov == nil {
|
||
cov = mat.NewSymDense(d.Dim(), nil)
|
||
} else if cov.Symmetric() == 0 {
|
||
*cov = *(cov.GrowSquare(d.dim).(*mat.SymDense))
|
||
} else if cov.Symmetric() != d.dim {
|
||
panic("normal: input matrix size mismatch")
|
||
}
|
||
scale := 1 / (d.sumAlpha * d.sumAlpha * (d.sumAlpha + 1))
|
||
for i := 0; i < d.dim; i++ {
|
||
ai := d.alpha[i]
|
||
v := ai * (d.sumAlpha - ai) * scale
|
||
cov.SetSym(i, i, v)
|
||
for j := i + 1; j < d.dim; j++ {
|
||
aj := d.alpha[j]
|
||
v := -ai * aj * scale
|
||
cov.SetSym(i, j, v)
|
||
}
|
||
}
|
||
return cov
|
||
}
|
||
|
||
// genLBeta computes the generalized LBeta function.
|
||
func (d *Dirichlet) genLBeta(alpha []float64) (lbeta, sumAlpha float64) {
|
||
for _, alpha := range d.alpha {
|
||
lg, _ := math.Lgamma(alpha)
|
||
lbeta += lg
|
||
sumAlpha += alpha
|
||
}
|
||
lg, _ := math.Lgamma(sumAlpha)
|
||
return lbeta - lg, sumAlpha
|
||
}
|
||
|
||
// Dim returns the dimension of the distribution.
|
||
func (d *Dirichlet) Dim() int {
|
||
return d.dim
|
||
}
|
||
|
||
// LogProb computes the log of the pdf of the point x.
|
||
//
|
||
// It does not check that ||x||_1 = 1.
|
||
func (d *Dirichlet) LogProb(x []float64) float64 {
|
||
dim := d.dim
|
||
if len(x) != dim {
|
||
panic(badSizeMismatch)
|
||
}
|
||
var lprob float64
|
||
for i, x := range x {
|
||
lprob += (d.alpha[i] - 1) * math.Log(x)
|
||
}
|
||
lprob -= d.lbeta
|
||
return lprob
|
||
}
|
||
|
||
// Mean returns the mean of the probability distribution at x. If the
|
||
// input argument is nil, a new slice will be allocated, otherwise the result
|
||
// will be put in-place into the receiver.
|
||
func (d *Dirichlet) Mean(x []float64) []float64 {
|
||
x = reuseAs(x, d.dim)
|
||
copy(x, d.alpha)
|
||
floats.Scale(1/d.sumAlpha, x)
|
||
return x
|
||
}
|
||
|
||
// Prob computes the value of the probability density function at x.
|
||
func (d *Dirichlet) Prob(x []float64) float64 {
|
||
return math.Exp(d.LogProb(x))
|
||
}
|
||
|
||
// Rand generates a random number according to the distributon.
|
||
// If the input slice is nil, new memory is allocated, otherwise the result is stored
|
||
// in place.
|
||
func (d *Dirichlet) Rand(x []float64) []float64 {
|
||
x = reuseAs(x, d.dim)
|
||
for i := range x {
|
||
x[i] = distuv.Gamma{Alpha: d.alpha[i], Beta: 1, Src: d.src}.Rand()
|
||
}
|
||
sum := floats.Sum(x)
|
||
floats.Scale(1/sum, x)
|
||
return x
|
||
}
|