// Copyright ©2017 The gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package fd import ( "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/mat" ) // ConstFunc is a constant function returning the value held by the type. type ConstFunc float64 func (c ConstFunc) Func(x []float64) float64 { return float64(c) } func (c ConstFunc) Grad(grad, x []float64) { for i := range grad { grad[i] = 0 } } func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) { n := len(x) for i := 0; i < n; i++ { for j := i; j < n; j++ { dst.SetSym(i, j, 0) } } } // LinearFunc is a linear function returning w*x+c. type LinearFunc struct { w []float64 c float64 } func (l LinearFunc) Func(x []float64) float64 { return floats.Dot(l.w, x) + l.c } func (l LinearFunc) Grad(grad, x []float64) { copy(grad, l.w) } func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) { n := len(x) for i := 0; i < n; i++ { for j := i; j < n; j++ { dst.SetSym(i, j, 0) } } } // QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c. type QuadFunc struct { a *mat.SymDense b *mat.VecDense c float64 } func (q QuadFunc) Func(x []float64) float64 { v := mat.NewVecDense(len(x), x) var tmp mat.VecDense tmp.MulVec(q.a, v) return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c } func (q QuadFunc) Grad(grad, x []float64) { var tmp mat.VecDense v := mat.NewVecDense(len(x), x) tmp.MulVec(q.a, v) for i := range grad { grad[i] = tmp.At(i, 0) + q.b.At(i, 0) } } func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) { n := len(x) for i := 0; i < n; i++ { for j := i; j < n; j++ { dst.SetSym(i, j, q.a.At(i, j)) } } }