mirror of
https://github.com/gonum/gonum.git
synced 2025-10-10 01:20:14 +08:00
Added LogSumExp
This commit is contained in:
41
sliceops.go
41
sliceops.go
@@ -163,3 +163,44 @@ func Norm(s []float64, L float64) (norm float64) {
|
||||
}
|
||||
return math.Pow(norm, 1/L)
|
||||
}
|
||||
|
||||
// Adds a constant to all of the values in s
|
||||
func AddConst(s []float64, c float64) {
|
||||
for i := range s {
|
||||
s[i] += c
|
||||
}
|
||||
}
|
||||
|
||||
// Multiplies every element in s by a constant
|
||||
func MulConst(s []float64, c float64) {
|
||||
for i := range s {
|
||||
s[i] *= c
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the log of the sum of the exponentials of the values in s
|
||||
func LogSumExp(s []float64) (logsumexp float64) {
|
||||
// Want to do this in a numerically stable way which avoids
|
||||
// overflow and underflow
|
||||
// TODO: Add in special case for two values
|
||||
|
||||
// First, find the maximum value in the slice.
|
||||
minval, _ := Max(s)
|
||||
if math.IsInf(minval, 0) {
|
||||
// If it's infinity eitherway, the logsumexp will be infinity as well
|
||||
// returning now avoids NaNs
|
||||
return minval
|
||||
}
|
||||
// Subtract off the largest value, so the largest value in
|
||||
// the new slice is 0
|
||||
AddConst(s, -minval)
|
||||
defer AddConst(s, minval) // make sure we add it back on at the end
|
||||
|
||||
// compute the sumexp part
|
||||
for _, val := range s {
|
||||
logsumexp += math.Exp(val)
|
||||
}
|
||||
// Take the log and add back on the constant taken out
|
||||
logsumexp = math.Log(logsumexp) + minval
|
||||
return logsumexp
|
||||
}
|
||||
|
Reference in New Issue
Block a user