blob: 324843c28f28c7e093aeedc8787410aae2c05ffc [file] [log] [blame]
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmv
import (
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/floats/scalar"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
)
func TestStudentTProbs(t *testing.T) {
src := rand.New(rand.NewSource(1))
for _, test := range []struct {
nu float64
mu []float64
sigma *mat.SymDense
x [][]float64
probs []float64
}{
{
nu: 3,
mu: []float64{0, 0},
sigma: mat.NewSymDense(2, []float64{1, 0, 0, 1}),
x: [][]float64{
{0, 0},
{1, -1},
{3, 4},
{-1, -2},
},
// Outputs compared with WolframAlpha.
probs: []float64{
0.159154943091895335768883,
0.0443811199724279860006777747927,
0.0005980371870904696541052658,
0.01370560783418571283428283,
},
},
{
nu: 4,
mu: []float64{2, -3},
sigma: mat.NewSymDense(2, []float64{8, -1, -1, 5}),
x: [][]float64{
{0, 0},
{1, -1},
{3, 4},
{-1, -2},
{2, -3},
},
// Outputs compared with WolframAlpha.
probs: []float64{
0.007360810111491788657953608191001,
0.0143309905845607117740440592999,
0.0005307774290578041397794096037035009801668903,
0.0115657422475668739943625904793879,
0.0254851872062589062995305736215,
},
},
} {
s, ok := NewStudentsT(test.mu, test.sigma, test.nu, src)
if !ok {
t.Fatal("bad test")
}
for i, x := range test.x {
xcpy := make([]float64, len(x))
copy(xcpy, x)
p := s.Prob(x)
if !floats.Same(x, xcpy) {
t.Errorf("X modified during call to prob, %v, %v", x, xcpy)
}
if !scalar.EqualWithinAbsOrRel(p, test.probs[i], 1e-10, 1e-10) {
t.Errorf("Probability mismatch. X = %v. Got %v, want %v.", x, p, test.probs[i])
}
}
}
}
// TestStudentsTRand draws a large number of samples from several StudentsT
// distributions and checks that the sample mean and the sample covariance
// agree with the values reported by Mean and CovarianceMatrix.
func TestStudentsTRand(t *testing.T) {
	src := rand.New(rand.NewSource(1))
	for cas, test := range []struct {
		mean   []float64
		cov    *mat.SymDense
		nu     float64
		tolcov float64
	}{
		{
			mean:   []float64{0, 0},
			cov:    mat.NewSymDense(2, []float64{1, 0, 0, 1}),
			nu:     4,
			tolcov: 1e-2,
		},
		{
			mean:   []float64{3, 4},
			cov:    mat.NewSymDense(2, []float64{5, 1.2, 1.2, 6}),
			nu:     8,
			tolcov: 1e-2,
		},
		{
			mean:   []float64{3, 4, -2},
			cov:    mat.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}),
			nu:     8,
			tolcov: 1e-2,
		},
	} {
		s, ok := NewStudentsT(test.mean, test.cov, test.nu, src)
		if !ok {
			t.Fatal("bad test")
		}

		// Each row of samps holds one draw from the distribution.
		const nSamples = 1e6
		dim := len(test.mean)
		samps := mat.NewDense(nSamples, dim, nil)
		for i := 0; i < nSamples; i++ {
			s.Rand(samps.RawRowView(i))
		}

		// Estimate the mean column by column. A single scratch buffer is
		// reused for every column instead of allocating nSamples floats
		// per dimension.
		estMean := make([]float64, dim)
		col := make([]float64, nSamples)
		for i := range estMean {
			estMean[i] = stat.Mean(mat.Col(col, i, samps), nil)
		}
		mean := s.Mean(nil)
		if !floats.EqualApprox(estMean, mean, 1e-2) {
			// Report the value that was actually compared against, and the
			// case index so the failing table entry can be identified.
			t.Errorf("Case %d: Mean mismatch: want: %v, got %v", cas, mean, estMean)
		}

		var cov, estCov mat.SymDense
		s.CovarianceMatrix(&cov)
		stat.CovarianceMatrix(&estCov, samps, nil)
		if !mat.EqualApprox(&estCov, &cov, test.tolcov) {
			t.Errorf("Case %d: Cov mismatch: want: %v, got %v", cas, &cov, &estCov)
		}
	}
}
// TestStudentsTConditional conditions a StudentsT on observed coordinates via
// ConditionStudentsT and compares the resulting nu, mu and sigma against
// values recomputed directly with explicit linear solves.
func TestStudentsTConditional(t *testing.T) {
	src := rand.New(rand.NewSource(1))
	for _, test := range []struct {
		mean   []float64
		cov    *mat.SymDense
		nu     float64
		idx    []int
		value  []float64
		tolcov float64
	}{
		{
			mean:   []float64{3, 4, -2},
			cov:    mat.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}),
			nu:     8,
			idx:    []int{0},
			value:  []float64{6},
			tolcov: 1e-2,
		},
	} {
		s, ok := NewStudentsT(test.mean, test.cov, test.nu, src)
		if !ok {
			t.Fatal("bad test")
		}
		sUp, ok := s.ConditionStudentsT(test.idx, test.value, src)
		if !ok {
			// Everything below reads fields of sUp, so there is nothing
			// meaningful to check if conditioning failed.
			t.Fatal("unexpected failure of ConditionStudentsT")
		}

		// Compute the other values by hand the inefficient way to compare
		newNu := test.nu + float64(len(test.idx))
		if newNu != sUp.nu {
			t.Errorf("Updated nu mismatch. Got %v, want %v", sUp.nu, newNu)
		}

		// Split the location vector into unobserved and observed parts.
		dim := len(test.mean)
		unob := findUnob(test.idx, dim)
		ob := test.idx

		muUnob := make([]float64, len(unob))
		for i, v := range unob {
			muUnob[i] = test.mean[v]
		}
		muOb := make([]float64, len(ob))
		for i, v := range ob {
			muOb[i] = test.mean[v]
		}

		// Partition sigma: sig11 is the unobserved block, sig22 the observed
		// block and sig12 the cross term between them.
		var sig11, sig22 mat.SymDense
		sig11.SubsetSym(&s.sigma, unob)
		sig22.SubsetSym(&s.sigma, ob)

		sig12 := mat.NewDense(len(unob), len(ob), nil)
		for i := range unob {
			for j := range ob {
				sig12.Set(i, j, s.sigma.At(unob[i], ob[j]))
			}
		}

		// shift = value - muOb.
		shift := make([]float64, len(ob))
		copy(shift, test.value)
		floats.Sub(shift, muOb)

		// newMu = muUnob + sig12 * sig22^-1 * shift. newMuVec shares its
		// backing storage with newMu, so MulVec writes into newMu.
		newMu := make([]float64, len(muUnob))
		newMuVec := mat.NewVecDense(len(muUnob), newMu)
		shiftVec := mat.NewVecDense(len(shift), shift)
		var tmp mat.VecDense
		err := tmp.SolveVec(&sig22, shiftVec)
		if err != nil {
			t.Errorf("unexpected error from vector solve: %v", err)
		}
		newMuVec.MulVec(sig12, &tmp)
		floats.Add(newMu, muUnob)

		if !floats.EqualApprox(newMu, sUp.mu, 1e-10) {
			t.Errorf("Mu mismatch. Got %v, want %v", sUp.mu, newMu)
		}

		// Expected sigma:
		//  (nu + shift^T sig22^-1 shift) / (nu + len(ob)) * (sig11 - sig12 sig22^-1 sig12^T)
		var tmp2 mat.Dense
		err = tmp2.Solve(&sig22, sig12.T())
		if err != nil {
			t.Errorf("unexpected error from dense solve: %v", err)
		}

		var tmp3 mat.Dense
		tmp3.Mul(sig12, &tmp2)
		tmp3.Sub(&sig11, &tmp3)

		dot := mat.Dot(shiftVec, &tmp)
		tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3)
		if !mat.EqualApprox(&tmp3, &sUp.sigma, 1e-10) {
			t.Errorf("Sigma mismatch")
		}
	}
}
// TestStudentsTMarginalSingle checks that the univariate marginal returned by
// MarginalStudentsTSingle for coordinate i has mean mu[i], scale parameter
// sqrt(sigma[i][i]) and the same nu as the multivariate distribution.
func TestStudentsTMarginalSingle(t *testing.T) {
	for _, test := range []struct {
		mu    []float64
		sigma *mat.SymDense
		nu    float64
	}{
		{
			mu:    []float64{2, 3, 4},
			sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			nu:    5,
		},
		{
			mu:    []float64{2, 3, 4, 5},
			sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
			nu:    6,
		},
	} {
		studentst, ok := NewStudentsT(test.mu, test.sigma, test.nu, nil)
		if !ok {
			t.Fatalf("Bad test, covariance matrix not positive definite")
		}
		for i, mean := range test.mu {
			st := studentst.MarginalStudentsTSingle(i, nil)
			if st.Mean() != mean {
				t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, st.Mean())
			}

			// The quantity under test is the Sigma field of the marginal,
			// so that is also the value reported on failure. StdDev() is a
			// separate accessor and is not what is compared here.
			std := math.Sqrt(test.sigma.At(i, i))
			if math.Abs(st.Sigma-std) > 1e-14 {
				t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, st.Sigma)
			}

			if st.Nu != test.nu {
				t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu)
			}
		}
	}
}