// blob: d6e1691cb61eaed047a08e8f41f46e944f8e7f97 (source-viewer header, kept as a comment so the file remains valid Go)
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"testing"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/floats"
)
// Dpotf2er is the interface exercised by Dpotf2Test.
type Dpotf2er interface {
// Dpotf2 is called with the triangle selector ul, the matrix order n and the
// matrix a stored row by row with leading dimension lda. It is expected to
// overwrite the selected triangle of a in place. ok must report whether the
// matrix was found to be positive definite; Dpotf2Test compares it against
// the expected value and, when true, compares a against the expected
// triangular factor.
Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
}
// Dpotf2Test checks impl.Dpotf2 against hard-coded expected results.
//
// Each case is run four times:
//   - with blas.Upper on a and its upper factor U;
//   - with blas.Lower on the transposes of a and U;
//   - each with lda equal to n and with a padded lda of n+5, so that stride
//     handling is exercised.
//
// In each case, pos is the expected ok result of Dpotf2. U is the expected
// upper triangular factor. U is only compared when pos is true. It must
// still be a square matrix of the same order as a, because it is transposed
// and flattened for every run.
func Dpotf2Test(t *testing.T, impl Dpotf2er) {
	for _, test := range []struct {
		a   [][]float64
		pos bool
		U   [][]float64
	}{
		{
			a: [][]float64{
				{23, 37, 34, 32},
				{108, 71, 48, 48},
				{109, 109, 67, 58},
				{106, 107, 106, 63},
			},
			pos: true,
			U: [][]float64{
				{4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
				{0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
				{0, 0, 3.582364210034111, 2.419258947036024},
				{0, 0, 0, 3.401680257083044},
			},
		},
		{
			a: [][]float64{
				{8, 2},
				{2, 4},
			},
			pos: true,
			U: [][]float64{
				{2.82842712474619, 0.707106781186547},
				{0, 1.870828693386971},
			},
		},
		{
			// Smallest non-empty matrix: the factor is the square root of
			// the single element.
			a: [][]float64{
				{4},
			},
			pos: true,
			U: [][]float64{
				{2},
			},
		},
		{
			// Symmetric but indefinite. With u00 = sqrt(1) = 1 and
			// u01 = 2/1 = 2, the second pivot is 1 - 2*2 = -3 <= 0, so the
			// factorization must report failure.
			a: [][]float64{
				{1, 2},
				{2, 1},
			},
			pos: false,
			// Placeholder only: not compared because pos is false.
			U: [][]float64{
				{0, 0},
				{0, 0},
			},
		},
	} {
		n := len(test.a[0])
		testDpotf2(t, impl, test.pos, test.a, test.U, n, blas.Upper)
		testDpotf2(t, impl, test.pos, test.a, test.U, n+5, blas.Upper)
		aT := transpose(test.a)
		L := transpose(test.U)
		testDpotf2(t, impl, test.pos, aT, L, n, blas.Lower)
		testDpotf2(t, impl, test.pos, aT, L, n+5, blas.Lower)
	}
}
// testDpotf2 runs impl.Dpotf2 on a copy of a and checks the result.
//
// The copy is made by flattenTri: row-major storage with leading dimension
// stride. Only the triangle selected by ul is taken from a; all other
// positions hold marker values.
//
// testPos is the expected ok result. When testPos is true, ans is flattened
// the same way and the entire flattened output is compared with it using
// floats.EqualApprox at tolerance 1e-14. So both the computed triangle and
// the untouched marker positions are checked. ans is not used when testPos
// is false and may then be nil.
//
// Failure messages include n, lda and the selected triangle, so it is clear
// which of the runs in Dpotf2Test failed.
func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
	n := len(a[0])
	upper := ul == blas.Upper
	aFlat := flattenTri(a, stride, ul)
	pos := impl.Dpotf2(ul, n, aFlat, stride)
	if pos != testPos {
		t.Errorf("n=%d, lda=%d, upper=%t: positive definite mismatch: want %v, got %v", n, stride, upper, testPos, pos)
		return
	}
	if !testPos {
		// Nothing further is specified for a failed factorization.
		return
	}
	ansFlat := flattenTri(ans, stride, ul)
	if !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
		t.Errorf("n=%d, lda=%d, upper=%t: result mismatch: want %v, got %v", n, stride, upper, ansFlat, aFlat)
	}
}
// flattenTri copies a into a newly allocated row-major slice in which
// consecutive rows start stride elements apart, and returns it.
//
// Only these elements are copied from a:
//   - the first n = len(a[0]) columns;
//   - within them, the triangle selected by ul, diagonal included.
//
// Every other position, including the padding columns j >= n, gets a
// distinct marker value. The markers are 1000, 1001, ..., assigned in
// row-major order. Two calls with the same shape, stride and ul therefore
// place identical markers. It panics if stride is smaller than len(a[0]).
func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
	rows, cols := len(a), len(a[0])
	if stride < cols {
		panic("bad stride")
	}
	isUpper := ul == blas.Upper
	out := make([]float64, rows*stride)
	marker := 1000.0
	for r := 0; r < rows; r++ {
		row := out[r*stride : (r+1)*stride]
		for c := range row {
			// Is (r, c) inside the referenced triangle?
			inTri := c >= r
			if !isUpper {
				inTri = c <= r
			}
			if c < cols && inTri {
				row[c] = a[r][c]
				continue
			}
			row[c] = marker
			marker++
		}
	}
	return out
}
// transpose returns a newly allocated transpose of the square matrix a,
// which is given as a slice of rows.
//
// It panics with "not square" if the number of rows differs from the length
// of the first row. It panics with "bad n size" if any row's length differs
// from that of the first row.
func transpose(a [][]float64) [][]float64 {
	rows, cols := len(a), len(a[0])
	if rows != cols {
		panic("not square")
	}
	out := make([][]float64, rows)
	for r := range out {
		out[r] = make([]float64, cols)
	}
	for r, row := range a {
		if len(row) != cols {
			panic("bad n size")
		}
		for c, v := range row {
			out[c][r] = v
		}
	}
	return out
}