blob: 6a676501cd26e1b58cae562c3f1fbcd5c2207018 [file] [log] [blame]
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmv
import (
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/diff/fd"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
)
// TestNormProbs builds two bivariate normals and hands a set of
// (distribution, location, expected log-probability) cases to
// testProbability for checking.
func TestNormProbs(t *testing.T) {
	dist1, ok := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	if !ok {
		// Stop immediately: the cases below cannot be evaluated without a
		// successfully constructed distribution.
		t.Fatal("bad test")
	}
	// NOTE(review): the backing data here is not symmetric (2 vs 0);
	// presumably only the upper triangle is read by NewSymDense — confirm.
	dist2, ok := NewNormal([]float64{6, 7}, mat.NewSymDense(2, []float64{8, 2, 0, 4}), nil)
	if !ok {
		t.Fatal("bad test")
	}
	testProbability(t, []probCase{
		{
			dist:    dist1,
			loc:     []float64{0, 0},
			logProb: -1.837877066409345,
		},
		{
			dist:    dist2,
			loc:     []float64{6, 7},
			logProb: -3.503979321496947,
		},
		{
			dist:    dist2,
			loc:     []float64{1, 2},
			logProb: -7.075407892925519,
		},
	})
}
// TestNewNormalChol constructs a Normal from a Cholesky factorization and
// exercises Rand and Prob on it. There are no value assertions; the test
// passes as long as those calls complete.
func TestNewNormalChol(t *testing.T) {
	for _, test := range []struct {
		mean []float64
		cov  *mat.SymDense
	}{
		{
			mean: []float64{2, 3},
			cov:  mat.NewSymDense(2, []float64{1, 0.1, 0.1, 1}),
		},
	} {
		var chol mat.Cholesky
		if ok := chol.Factorize(test.cov); !ok {
			// Report through the testing framework instead of panicking so
			// the failure is attributed to this test.
			t.Fatal("bad test")
		}
		n := NewNormalChol(test.mean, &chol, nil)
		// Generate a random number and calculate probability to ensure things
		// have been set properly. See issue #426.
		x := n.Rand(nil)
		_ = n.Prob(x)
	}
}
// TestNormRand draws a large number of samples with Normal.Rand and checks
// that the sample mean and sample covariance match the parameters the
// distribution was built from, using a tolerance of 1e-2.
func TestNormRand(t *testing.T) {
	for _, test := range []struct {
		mean []float64
		cov  []float64
	}{
		{
			mean: []float64{0, 0},
			cov: []float64{
				1, 0,
				0, 1,
			},
		},
		{
			mean: []float64{0, 0},
			cov: []float64{
				1, 0.9,
				0.9, 1,
			},
		},
		{
			mean: []float64{6, 7},
			cov: []float64{
				5, 0.9,
				0.9, 2,
			},
		},
	} {
		dim := len(test.mean)
		cov := mat.NewSymDense(dim, test.cov)
		// NOTE(review): nil is passed as the random source; presumably this
		// selects a default source, so the samples are not seeded here — confirm.
		n, ok := NewNormal(test.mean, cov, nil)
		if !ok {
			// Abort: n must not be used after a failed construction.
			t.Fatalf("bad covariance matrix")
		}

		// Each row of samps receives one draw.
		nSamples := 1000000
		samps := mat.NewDense(nSamples, dim, nil)
		for i := 0; i < nSamples; i++ {
			n.Rand(samps.RawRowView(i))
		}

		// Per-dimension sample mean.
		estMean := make([]float64, dim)
		for i := range estMean {
			estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil)
		}
		if !floats.EqualApprox(estMean, test.mean, 1e-2) {
			t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
		}

		var estCov mat.SymDense
		stat.CovarianceMatrix(&estCov, samps, nil)
		if !mat.EqualApprox(&estCov, cov, 1e-2) {
			t.Errorf("Cov mismatch: want: %v, got %v", cov, &estCov)
		}
	}
}
// TestNormalQuantile pushes vectors of uniform random numbers (from a source
// seeded with 1) through Normal.Quantile and checks that the mean and
// covariance of the results match the distribution parameters, using a
// tolerance of 1e-2.
func TestNormalQuantile(t *testing.T) {
	for _, test := range []struct {
		mean []float64
		cov  []float64
	}{
		{
			mean: []float64{6, 7},
			cov: []float64{
				5, 0.9,
				0.9, 2,
			},
		},
	} {
		dim := len(test.mean)
		cov := mat.NewSymDense(dim, test.cov)
		n, ok := NewNormal(test.mean, cov, nil)
		if !ok {
			// Abort: n must not be used after a failed construction.
			t.Fatalf("bad covariance matrix")
		}

		nSamples := 1000000
		rnd := rand.New(rand.NewSource(1))
		samps := mat.NewDense(nSamples, dim, nil)
		// tmp holds the uniform inputs for one Quantile call and is reused
		// across iterations.
		tmp := make([]float64, dim)
		for i := 0; i < nSamples; i++ {
			for j := range tmp {
				tmp[j] = rnd.Float64()
			}
			n.Quantile(samps.RawRowView(i), tmp)
		}

		// Per-dimension sample mean.
		estMean := make([]float64, dim)
		for i := range estMean {
			estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil)
		}
		if !floats.EqualApprox(estMean, test.mean, 1e-2) {
			t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
		}

		var estCov mat.SymDense
		stat.CovarianceMatrix(&estCov, samps, nil)
		if !mat.EqualApprox(&estCov, cov, 1e-2) {
			t.Errorf("Cov mismatch: want: %v, got %v", cov, &estCov)
		}
	}
}
// TestConditionNormal checks Normal.ConditionNormal in three ways:
//
//  1. table cases with hand-written expected mean and covariance, where the
//     observed and remaining variables are uncorrelated;
//  2. bivariate cases compared against the closed-form conditional mean and
//     variance;
//  3. a sampling check, where samples whose observed coordinates fall within
//     1e-1 of the conditioning values are kept and their empirical mean and
//     covariance are compared with the conditioned distribution.
func TestConditionNormal(t *testing.T) {
	// Uncorrelated values shouldn't influence the updated values.
	for _, test := range []struct {
		mu       []float64
		sigma    *mat.SymDense
		observed []int
		values   []float64
		newMu    []float64
		newSigma *mat.SymDense
	}{
		{
			mu:       []float64{2, 3},
			sigma:    mat.NewSymDense(2, []float64{2, 0, 0, 5}),
			observed: []int{0},
			values:   []float64{10},
			newMu:    []float64{3},
			newSigma: mat.NewSymDense(1, []float64{5}),
		},
		{
			mu:       []float64{2, 3},
			sigma:    mat.NewSymDense(2, []float64{2, 0, 0, 5}),
			observed: []int{1},
			values:   []float64{10},
			newMu:    []float64{2},
			newSigma: mat.NewSymDense(1, []float64{2}),
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
			observed: []int{1},
			values:   []float64{10},
			newMu:    []float64{2, 4},
			newSigma: mat.NewSymDense(2, []float64{2, 0, 0, 10}),
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}),
			observed: []int{0, 1},
			values:   []float64{10, 15},
			newMu:    []float64{4},
			newSigma: mat.NewSymDense(1, []float64{10}),
		},
		{
			mu:       []float64{2, 3, 4, 5},
			sigma:    mat.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}),
			observed: []int{0, 1},
			values:   []float64{10, 15},
			newMu:    []float64{4, 5},
			newSigma: mat.NewSymDense(2, []float64{10, 2, 2, 3}),
		},
	} {
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, original sigma not positive definite")
		}
		newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil)
		if !ok {
			t.Fatalf("Bad test, update failure")
		}

		if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) {
			t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu)
		}

		// Rebuild the covariance from the stored Cholesky factor for comparison.
		var sigma mat.SymDense
		newNormal.chol.ToSym(&sigma)
		if !mat.EqualApprox(test.newSigma, &sigma, 1e-12) {
			t.Errorf("Updated sigma mismatch.\nWant:\n% v\nGot:\n% v\n", test.newSigma, &sigma)
		}
	}

	// Test bivariate case where the update rule is analytic
	for _, test := range []struct {
		mu    []float64
		std   []float64
		rho   float64
		value float64
	}{
		{
			mu:    []float64{2, 3},
			std:   []float64{3, 5},
			rho:   0.9,
			value: 1000,
		},
		{
			mu:    []float64{2, 3},
			std:   []float64{3, 5},
			rho:   -0.9,
			value: 1000,
		},
	} {
		std := test.std
		rho := test.rho
		// Covariance built from the standard deviations and correlation:
		// diagonal std[i]^2, off-diagonal std[0]*std[1]*rho.
		sigma := mat.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]})
		normal, ok := NewNormal(test.mu, sigma, nil)
		if !ok {
			t.Fatalf("Bad test, original sigma not positive definite")
		}
		newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil)
		if !ok {
			t.Fatalf("Bad test, update failed")
		}
		var newSigma mat.SymDense
		newNormal.chol.ToSym(&newSigma)

		trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1])
		if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 {
			t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0])
		}
		trueVar := (1 - rho*rho) * std[0] * std[0]
		if got := newSigma.At(0, 0); math.Abs(trueVar-got) > 1e-14 {
			// Report the variance values that were actually compared.
			t.Errorf("Variance mismatch. Want %v, got %v", trueVar, got)
		}
	}

	// Test via sampling.
	for _, test := range []struct {
		mu         []float64
		sigma      *mat.SymDense
		observed   []int
		unobserved []int
		value      []float64
	}{
		// The indices in unobserved must be in ascending order for this test.
		{
			mu:         []float64{2, 3, 4},
			sigma:      mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			observed:   []int{0},
			unobserved: []int{1, 2},
			value:      []float64{1.9},
		},
		{
			mu:         []float64{2, 3, 4, 5},
			sigma:      mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
			observed:   []int{0, 3},
			unobserved: []int{1, 2},
			value:      []float64{1.9, 2.9},
		},
	} {
		totalSamp := 4000000
		var nSamp int
		samples := mat.NewDense(totalSamp, len(test.mu), nil)
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			// Abort: normal must not be used after a failed construction.
			t.Fatalf("bad test")
		}

		// Keep only the draws whose observed coordinates are all within 1e-1
		// of the conditioning values; accepted draws fill the leading rows
		// of samples.
		sample := make([]float64, len(test.mu))
		for i := 0; i < totalSamp; i++ {
			normal.Rand(sample)
			isClose := true
			for j, v := range test.observed {
				if math.Abs(sample[v]-test.value[j]) > 1e-1 {
					isClose = false
					break
				}
			}
			if isClose {
				samples.SetRow(nSamp, sample)
				nSamp++
			}
		}

		if nSamp < 100 {
			t.Errorf("bad test, not enough samples")
			continue
		}
		// Trim to the rows that were actually filled.
		samples = samples.Slice(0, nSamp, 0, len(test.mu)).(*mat.Dense)

		// Compute mean and covariance matrix.
		estMean := make([]float64, len(test.mu))
		for i := range estMean {
			estMean[i] = stat.Mean(mat.Col(nil, i, samples), nil)
		}
		var estCov mat.SymDense
		stat.CovarianceMatrix(&estCov, samples, nil)

		// Compute update rule.
		newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil)
		if !ok {
			t.Fatalf("Bad test, update failure")
		}

		// Restrict the empirical statistics to the unobserved indices so they
		// line up with the conditioned distribution.
		var subEstMean []float64
		for _, v := range test.unobserved {
			subEstMean = append(subEstMean, estMean[v])
		}
		subEstCov := mat.NewSymDense(len(test.unobserved), nil)
		for i := 0; i < len(test.unobserved); i++ {
			for j := i; j < len(test.unobserved); j++ {
				subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j]))
			}
		}

		for i, v := range subEstMean {
			if math.Abs(newNormal.mu[i]-v) > 5e-2 {
				t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v)
			}
		}
		var sigma mat.SymDense
		newNormal.chol.ToSym(&sigma)
		if !mat.EqualApprox(&sigma, subEstCov, 1e-1) {
			t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, &sigma)
		}
	}
}
// TestCovarianceMatrix checks that Normal.CovarianceMatrix reproduces the
// covariance the distribution was constructed with, both when the
// destination is a zero-value SymDense and when it is pre-sized.
func TestCovarianceMatrix(t *testing.T) {
	const tol = 1e-14

	cases := []struct {
		mu    []float64
		sigma *mat.SymDense
	}{
		{
			mu:    []float64{2, 3, 4},
			sigma: mat.NewSymDense(3, []float64{1, 0.5, 3, 0.5, 8, -1, 3, -1, 15}),
		},
	}

	for _, tc := range cases {
		dist, ok := NewNormal(tc.mu, tc.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, covariance matrix not positive definite")
		}

		// Zero-value destination.
		var got mat.SymDense
		dist.CovarianceMatrix(&got)
		if !mat.EqualApprox(&got, tc.sigma, tol) {
			t.Errorf("Covariance mismatch with nil input")
		}

		// Destination already allocated at the matching dimension.
		got = *mat.NewSymDense(tc.sigma.Symmetric(), nil)
		dist.CovarianceMatrix(&got)
		if !mat.EqualApprox(&got, tc.sigma, tol) {
			t.Errorf("Covariance mismatch with supplied input")
		}
	}
}
// TestMarginal compares the mean and covariance of Normal.MarginalNormal
// against statistics estimated from samples of the full distribution,
// restricted to the marginalized indices.
func TestMarginal(t *testing.T) {
	for _, test := range []struct {
		mu       []float64
		sigma    *mat.SymDense
		marginal []int
	}{
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			marginal: []int{0},
		},
		{
			mu:       []float64{2, 3, 4},
			sigma:    mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			marginal: []int{0, 2},
		},
		{
			mu:       []float64{2, 3, 4, 5},
			sigma:    mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
			marginal: []int{0, 3},
		},
	} {
		normal, ok := NewNormal(test.mu, test.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, covariance matrix not positive definite")
		}
		marginal, ok := normal.MarginalNormal(test.marginal, nil)
		if !ok {
			t.Fatalf("Bad test, marginal matrix not positive definite")
		}

		// Sample from the full distribution; each row of samps is one draw.
		dim := normal.Dim()
		nSamples := 1000000
		samps := mat.NewDense(nSamples, dim, nil)
		for i := 0; i < nSamples; i++ {
			normal.Rand(samps.RawRowView(i))
		}

		estMean := make([]float64, dim)
		for i := range estMean {
			estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil)
		}
		// i indexes the marginal distribution, v the full distribution.
		for i, v := range test.marginal {
			if math.Abs(marginal.mu[i]-estMean[v]) > 1e-2 {
				t.Errorf("Mean mismatch: want: %v, got %v", estMean[v], marginal.mu[i])
			}
		}

		var marginalCov mat.SymDense
		marginal.CovarianceMatrix(&marginalCov)
		var estCov mat.SymDense
		stat.CovarianceMatrix(&estCov, samps, nil)
		for i, v1 := range test.marginal {
			for j, v2 := range test.marginal {
				c := marginalCov.At(i, j)
				ec := estCov.At(v1, v2)
				if math.Abs(c-ec) > 5e-2 {
					// As in the mean check above, the sample estimate is the
					// reference ("want") and the marginal's value is "got".
					t.Errorf("Cov mismatch element i = %d, j = %d: want: %v, got %v", i, j, ec, c)
				}
			}
		}
	}
}
// TestMarginalSingle checks Normal.MarginalNormalSingle two ways: against
// the mean and diagonal variance of fixed distributions, and against the
// one-index result of MarginalNormal on randomly generated distributions.
func TestMarginalSingle(t *testing.T) {
	fixed := []struct {
		mu    []float64
		sigma *mat.SymDense
	}{
		{
			mu:    []float64{2, 3, 4},
			sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
		},
		{
			mu:    []float64{2, 3, 4, 5},
			sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
		},
	}
	for _, tc := range fixed {
		dist, ok := NewNormal(tc.mu, tc.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, covariance matrix not positive definite")
		}
		for idx := range tc.mu {
			wantMean := tc.mu[idx]
			uni := dist.MarginalNormalSingle(idx, nil)
			// The mean is compared exactly; the standard deviation within 1e-14.
			if uni.Mean() != wantMean {
				t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", idx, wantMean, uni.Mean())
			}
			wantStd := math.Sqrt(tc.sigma.At(idx, idx))
			if diff := math.Abs(uni.StdDev() - wantStd); diff > 1e-14 {
				t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", idx, wantStd, uni.StdDev())
			}
		}
	}

	// Test matching with TestMarginal.
	src := rand.New(rand.NewSource(1))
	for trial := 0; trial < 10; trial++ {
		// Random dimension in [1, 10], random mean, and a covariance formed
		// from a random square matrix via SymOuterK.
		n := src.Intn(10) + 1
		mean := make([]float64, n)
		for k := range mean {
			mean[k] = src.Float64()
		}
		raw := make([]float64, n*n)
		for k := range raw {
			raw[k] = src.Float64()
		}
		var covar mat.SymDense
		covar.SymOuterK(1, mat.NewDense(n, n, raw))

		dist, ok := NewNormal(mean, &covar, nil)
		if !ok {
			t.Fatal("bad test")
		}
		for idx := 0; idx < n; idx++ {
			uni := dist.MarginalNormalSingle(idx, nil)
			multi, ok := dist.MarginalNormal([]int{idx}, nil)
			if !ok {
				t.Fatal("bad test")
			}
			if meanDiff := math.Abs(uni.Mean() - multi.Mean(nil)[0]); meanDiff > 1e-14 {
				t.Errorf("Mean mismatch")
			}
			var multiCov mat.SymDense
			multi.CovarianceMatrix(&multiCov)
			if varDiff := math.Abs(uni.Variance() - multiCov.At(0, 0)); varDiff > 1e-14 {
				t.Errorf("Variance mismatch")
			}
		}
	}
}
// TestNormalScoreInput checks that Normal.ScoreInput leaves its input
// untouched and agrees, within 1e-4, with a finite-difference gradient of
// Normal.LogProb at the same point.
func TestNormalScoreInput(t *testing.T) {
	cases := []struct {
		mu    []float64
		sigma *mat.SymDense
		x     []float64
	}{
		{
			mu:    []float64{2, 3, 4},
			sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
			x:     []float64{1, 3.1, -2},
		},
		{
			mu:    []float64{2, 3, 4, 5},
			sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
			x:     []float64{1, 3.1, -2, 5},
		},
	}
	for idx, tc := range cases {
		dist, ok := NewNormal(tc.mu, tc.sigma, nil)
		if !ok {
			t.Fatalf("Bad test, covariance matrix not positive definite")
		}

		// Work on a copy so that any mutation of the argument is detectable.
		point := make([]float64, len(tc.x))
		copy(point, tc.x)

		analytic := dist.ScoreInput(nil, point)
		if !floats.Equal(point, tc.x) {
			t.Errorf("x modified during call to ScoreInput")
		}

		numeric := fd.Gradient(nil, dist.LogProb, point, nil)
		if !floats.EqualApprox(analytic, numeric, 1e-4) {
			t.Errorf("Case %d: derivative mismatch. Got %v, want %v", idx, analytic, numeric)
		}
	}
}