blob: 96350c64c3ab694ac9056ab06a0efcb9a2bb393a [file] [log] [blame]
// Copyright ©2016 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmat
import (
"math"
"math/rand"
"testing"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
)
// TestWishart checks Wishart.LogProbSym against reference log-probabilities
// computed with scipy. It also checks that LogProbSymChol, given the Cholesky
// factorization of the same matrix, agrees with LogProbSym. Finally it
// exercises RandChol and RandSym both with a nil destination and with a value
// returned by a previous call.
func TestWishart(t *testing.T) {
	for c, test := range []struct {
		v   *mat.SymDense
		nu  float64
		xs  []*mat.SymDense
		lps []float64
	}{
		// Logprob data compared with scipy.
		{
			v:  mat.NewSymDense(2, []float64{1, 0, 0, 1}),
			nu: 4,
			xs: []*mat.SymDense{
				mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}),
			},
			lps: []float64{-4.2357432031863409},
		},
		{
			v:  mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}),
			nu: 5,
			xs: []*mat.SymDense{
				mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}),
				mat.NewSymDense(2, []float64{0.3, -0.1, -0.1, 0.7}),
			},
			lps: []float64{-4.2476495605333575, -4.9993285370378633},
		},
		{
			v:  mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}),
			nu: 5,
			xs: []*mat.SymDense{
				mat.NewSymDense(3, []float64{1, 0.2, -0.3, 0.2, 0.6, -0.2, -0.3, -0.2, 6}),
			},
			lps: []float64{-11.010982249229421},
		},
	} {
		// Fixture sanity: every matrix needs a matching reference value,
		// otherwise test.lps[i] below would panic instead of failing cleanly.
		if len(test.xs) != len(test.lps) {
			t.Fatalf("Case %d: bad test: %d matrices but %d log-probabilities", c, len(test.xs), len(test.lps))
		}
		w, ok := NewWishart(test.v, test.nu, nil)
		if !ok {
			t.Fatalf("Case %d: bad test: NewWishart rejected the parameters", c)
		}
		for i, x := range test.xs {
			lp := w.LogProbSym(x)

			var chol mat.Cholesky
			if !chol.Factorize(x) {
				t.Fatalf("Case %d, test %d: bad test: Cholesky factorization of x failed", c, i)
			}
			lpc := w.LogProbSymChol(&chol)
			if math.Abs(lp-lpc) > 1e-14 {
				t.Errorf("Case %d, test %d: probability mismatch between chol and not: %v vs %v", c, i, lp, lpc)
			}
			if !floats.EqualWithinAbsOrRel(lp, test.lps[i], 1e-14, 1e-14) {
				t.Errorf("Case %d, test %d: got %v, want %v", c, i, lp, test.lps[i])
			}
		}

		// Smoke-test the samplers: call once with a nil destination, then
		// again passing back the value that call returned.
		ch := w.RandChol(nil)
		w.RandChol(ch)
		s := w.RandSym(nil)
		w.RandSym(s)
	}
}
// TestWishartRand checks that the sample mean of RandSym draws is close to
// the mean reported by MeanSym. The comparison uses mat.EqualApprox with a
// per-case tolerance. Each case seeds its own generator with a fixed value
// so the test is deterministic.
func TestWishartRand(t *testing.T) {
	for c, test := range []struct {
		v       *mat.SymDense
		nu      float64
		samples int
		tol     float64
	}{
		{
			v:       mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}),
			nu:      5,
			samples: 30000,
			tol:     3e-2,
		},
		{
			v:       mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}),
			nu:      5,
			samples: 30000,
			tol:     3e-2,
		},
		{
			v: mat.NewSymDense(4, []float64{
				0.8, 0.3, 0.1, -0.2,
				0.3, 0.7, -0.1, 0.4,
				0.1, -0.1, 7, 1,
				-0.2, -0.1, 1, 6}),
			nu:      6,
			samples: 30000,
			tol:     1e-1,
		},
	} {
		rnd := rand.New(rand.NewSource(1))
		dim := test.v.Symmetric()
		w, ok := NewWishart(test.v, test.nu, rnd)
		if !ok {
			t.Fatalf("Case %d: bad test: NewWishart rejected the parameters", c)
		}

		// Sum the draws and divide once at the end. Scaling each sample
		// individually costs an extra pass over the matrix per draw.
		mean := mat.NewSymDense(dim, nil)
		x := mat.NewSymDense(dim, nil)
		for i := 0; i < test.samples; i++ {
			w.RandSym(x)
			mean.AddSym(mean, x)
		}
		mean.ScaleSym(1/float64(test.samples), mean)

		trueMean := w.MeanSym(nil)
		if !mat.EqualApprox(trueMean, mean, test.tol) {
			t.Errorf("Case %d: Mismatch between estimated and true mean. Got\n%0.4v\nWant\n%0.4v\n", c, mat.Formatted(mean), mat.Formatted(trueMean))
		}
	}
}