blob: 8d75fc502f8e73739af458e42c7d6391b9d647c4 [file] [log] [blame]
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmv
import (
"math"
"testing"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
)
// prober is the pair of density queries exercised by testProbability: a
// log-probability and a probability, both evaluated at a location x.
type prober interface {
	// LogProb is called with the test location; testProbability compares
	// its result against the expected log-probability of the case.
	LogProb(x []float64) float64
	// Prob is called with the same location; testProbability compares its
	// result against the exponential of the expected log-probability.
	Prob(x []float64) float64
}
// probCase is a single table entry consumed by testProbability.
type probCase struct {
	// dist is the distribution whose LogProb and Prob are evaluated.
	dist prober
	// loc is the location passed to both LogProb and Prob.
	loc []float64
	// logProb is the expected value of LogProb(loc); the expected value of
	// Prob(loc) is derived from it with math.Exp.
	logProb float64
}
// testProbability checks every entry of cases against its distribution.
// For each case, LogProb(loc) must agree with the expected log-probability
// and Prob(loc) must agree with its exponential, both to an absolute
// tolerance of 1e-14. Mismatches are reported with t.Errorf, so the
// remaining cases are still evaluated after a failure.
func testProbability(t *testing.T, cases []probCase) {
	const tol = 1e-14

	// agree reports whether got matches want. Exact equality is tried first
	// so that equal infinities match even though their difference is NaN.
	// The tolerance test is phrased as "<=" so that a NaN on only one side
	// counts as a mismatch; a "difference > tol" test is always false for
	// NaN and would let such a result pass unnoticed.
	agree := func(got, want float64) bool {
		if got == want || (math.IsNaN(got) && math.IsNaN(want)) {
			return true
		}
		return math.Abs(got-want) <= tol
	}

	for _, test := range cases {
		if logProb := test.dist.LogProb(test.loc); !agree(logProb, test.logProb) {
			t.Errorf("LogProb mismatch: want: %v, got: %v", test.logProb, logProb)
		}
		wantProb := math.Exp(test.logProb)
		if prob := test.dist.Prob(test.loc); !agree(prob, wantProb) {
			t.Errorf("Prob mismatch: want: %v, got: %v", wantProb, prob)
		}
	}
}
// generateSamples walks the rows of x from first to last and hands the raw
// backing slice of each row to r.Rand, presumably so that r writes one
// sample into every row (Rander is declared elsewhere in the package).
func generateSamples(x *mat.Dense, r Rander) {
	rows, _ := x.Dims()
	for row := 0; row < rows; row++ {
		sample := x.RawRowView(row)
		r.Rand(sample)
	}
}
// Meaner is the behavior checked by checkMean: a Mean method that returns
// the mean as a slice.
type Meaner interface {
	// Mean is called by checkMean both with nil and with a preallocated
	// slice; the helper expects the returned values and the contents of
	// the preallocated slice to be identical.
	Mean(dst []float64) []float64
}
// checkMean verifies the Mean method of m and reports failures through t,
// tagging each message with the case number cas. Two properties are checked:
// Mean(nil) and Mean called with a preallocated slice must yield identical
// values, and the reported mean must match the per-column sample means of x
// as judged by floats.EqualApprox with tolerance tol.
func checkMean(t *testing.T, cas int, x *mat.Dense, m Meaner, tol float64) {
	returned := m.Mean(nil)

	// A caller-supplied destination must end up holding the same values.
	into := make([]float64, len(returned))
	m.Mean(into)
	if !floats.Equal(returned, into) {
		t.Errorf("Mean mismatch when providing nil and slice. Case %v", cas)
	}

	// Estimate the mean of each dimension from the matching column of x,
	// reusing one scratch slice for the extracted columns.
	nSamples, _ := x.Dims()
	scratch := make([]float64, nSamples)
	empirical := make([]float64, 0, len(returned))
	for j := 0; j < len(returned); j++ {
		column := mat.Col(scratch, j, x)
		empirical = append(empirical, stat.Mean(column, nil))
	}
	if floats.EqualApprox(returned, empirical, tol) {
		return
	}
	t.Errorf("Returned mean and sample mean mismatch. Case %v. Empirical %v, returned %v", cas, empirical, returned)
}
// Cover is the behavior checked by checkCov: a CovarianceMatrix method that
// stores its result into the supplied symmetric matrix.
type Cover interface {
	// CovarianceMatrix is called by checkCov with both a zero-value
	// mat.SymDense and one already sized to the result; the helper reads
	// the covariance back out of dst afterwards.
	CovarianceMatrix(dst *mat.SymDense)
}
// checkCov verifies the CovarianceMatrix method of c and reports failures
// through t, tagging each message with the case number cas. It checks that
// the result is the same whether the destination is a zero-value matrix or
// one pre-sized to the result, that a second call with a zero-value
// destination reproduces the first, and that the reported covariance matches
// the sample covariance of x as judged by mat.EqualApprox with tolerance tol.
func checkCov(t *testing.T, cas int, x *mat.Dense, c Cover, tol float64) {
	var cov mat.SymDense
	c.CovarianceMatrix(&cov)

	// Repeat the call with a destination already sized to the result.
	n := cov.Symmetric()
	cov2 := mat.NewSymDense(n, nil)
	c.CovarianceMatrix(cov2)
	if !mat.Equal(&cov, cov2) {
		t.Errorf("Cov mismatch when providing nil and matrix. Case %v", cas)
	}

	// Repeat the call with a fresh zero-value destination.
	var cov3 mat.SymDense
	c.CovarianceMatrix(&cov3)
	if !mat.Equal(&cov, &cov3) {
		t.Errorf("Cov mismatch when providing zero matrix. Case %v", cas)
	}

	// Check that the covariance matrix matches the samples.
	var covEst mat.SymDense
	stat.CovarianceMatrix(&covEst, x, nil)
	if !mat.EqualApprox(&covEst, &cov, tol) {
		// The message spells "Case" and "Returned" to match the other
		// failure messages of these helpers.
		t.Errorf("Returned cov and sample cov mismatch. Case %v.\nGot:\n%0.4v\nWant:\n%0.4v", cas, mat.Formatted(&cov), mat.Formatted(&covEst))
	}
}