mirror of
https://github.com/gonum/gonum.git
synced 2025-09-30 13:02:20 +08:00
88 lines
1.7 KiB
Go
88 lines
1.7 KiB
Go
// Copyright ©2017 The gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package fd
|
|
|
|
import (
|
|
"gonum.org/v1/gonum/floats"
|
|
"gonum.org/v1/gonum/mat"
|
|
)
|
|
|
|
// ConstFunc is a constant function returning the value held by the type.
|
|
type ConstFunc float64
|
|
|
|
func (c ConstFunc) Func(x []float64) float64 {
|
|
return float64(c)
|
|
}
|
|
|
|
func (c ConstFunc) Grad(grad, x []float64) {
|
|
for i := range grad {
|
|
grad[i] = 0
|
|
}
|
|
}
|
|
|
|
func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) {
|
|
n := len(x)
|
|
for i := 0; i < n; i++ {
|
|
for j := i; j < n; j++ {
|
|
dst.SetSym(i, j, 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
// LinearFunc is a linear function returning w*x+c.
|
|
type LinearFunc struct {
|
|
w []float64
|
|
c float64
|
|
}
|
|
|
|
func (l LinearFunc) Func(x []float64) float64 {
|
|
return floats.Dot(l.w, x) + l.c
|
|
}
|
|
|
|
func (l LinearFunc) Grad(grad, x []float64) {
|
|
copy(grad, l.w)
|
|
}
|
|
|
|
func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) {
|
|
n := len(x)
|
|
for i := 0; i < n; i++ {
|
|
for j := i; j < n; j++ {
|
|
dst.SetSym(i, j, 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
// QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c.
|
|
type QuadFunc struct {
|
|
a *mat.SymDense
|
|
b *mat.VecDense
|
|
c float64
|
|
}
|
|
|
|
func (q QuadFunc) Func(x []float64) float64 {
|
|
v := mat.NewVecDense(len(x), x)
|
|
var tmp mat.VecDense
|
|
tmp.MulVec(q.a, v)
|
|
return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c
|
|
}
|
|
|
|
func (q QuadFunc) Grad(grad, x []float64) {
|
|
var tmp mat.VecDense
|
|
v := mat.NewVecDense(len(x), x)
|
|
tmp.MulVec(q.a, v)
|
|
for i := range grad {
|
|
grad[i] = tmp.At(i, 0) + q.b.At(i, 0)
|
|
}
|
|
}
|
|
|
|
func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) {
|
|
n := len(x)
|
|
for i := 0; i < n; i++ {
|
|
for j := i; j < n; j++ {
|
|
dst.SetSym(i, j, q.a.At(i, j))
|
|
}
|
|
}
|
|
}
|