blob: 1c1dd42e92cb049d934275f2937b2fd3bca72ac4 [file] [log] [blame]
// Copyright ©2017 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fd
import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
)
// ConstFunc is a constant function returning the value held by the type.
type ConstFunc float64
func (c ConstFunc) Func(x []float64) float64 {
return float64(c)
}
func (c ConstFunc) Grad(grad, x []float64) {
for i := range grad {
grad[i] = 0
}
}
func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) {
n := len(x)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
dst.SetSym(i, j, 0)
}
}
}
// LinearFunc is a linear function returning w*x+c.
type LinearFunc struct {
w []float64
c float64
}
func (l LinearFunc) Func(x []float64) float64 {
return floats.Dot(l.w, x) + l.c
}
func (l LinearFunc) Grad(grad, x []float64) {
copy(grad, l.w)
}
func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) {
n := len(x)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
dst.SetSym(i, j, 0)
}
}
}
// QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c.
type QuadFunc struct {
a *mat.SymDense
b *mat.VecDense
c float64
}
func (q QuadFunc) Func(x []float64) float64 {
v := mat.NewVecDense(len(x), x)
var tmp mat.VecDense
tmp.MulVec(q.a, v)
return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c
}
func (q QuadFunc) Grad(grad, x []float64) {
var tmp mat.VecDense
v := mat.NewVecDense(len(x), x)
tmp.MulVec(q.a, v)
for i := range grad {
grad[i] = tmp.At(i, 0) + q.b.At(i, 0)
}
}
func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) {
n := len(x)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
dst.SetSym(i, j, q.a.At(i, j))
}
}
}