blob: 8fdbd4e8b1e3d13d0bddb76b5988d5df0111a2fa [file] [log] [blame]
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distmv
import (
// prober is the pair of methods exercised by testProbability. The two are
// expected to be consistent: testProbability checks Prob(x) against the
// exponential of the expected LogProb(x) value.
type prober interface {
	Prob(x []float64) float64
	LogProb(x []float64) float64
}
type probCase struct {
dist prober
loc []float64
logProb float64
func testProbability(t *testing.T, cases []probCase) {
for _, test := range cases {
logProb := test.dist.LogProb(test.loc)
if math.Abs(logProb-test.logProb) > 1e-14 {
t.Errorf("LogProb mismatch: want: %v, got: %v", test.logProb, logProb)
prob := test.dist.Prob(test.loc)
if math.Abs(prob-math.Exp(test.logProb)) > 1e-14 {
t.Errorf("Prob mismatch: want: %v, got: %v", math.Exp(test.logProb), prob)
func generateSamples(x *mat.Dense, r Rander) {
n, _ := x.Dims()
for i := 0; i < n; i++ {
// Meaner is implemented by types that can report a mean vector. checkMean
// calls Mean with a nil argument and uses the returned slice as the mean.
// NOTE(review): presumably a non-nil argument is used as the destination for
// the result — confirm against the implementations.
type Meaner interface {
	Mean([]float64) []float64
}
func checkMean(t *testing.T, cas int, x *mat.Dense, m Meaner, tol float64) {
mean := m.Mean(nil)
// Check that the answer is identical when using nil or non-nil.
mean2 := make([]float64, len(mean))
if !floats.Equal(mean, mean2) {
t.Errorf("Mean mismatch when providing nil and slice. Case %v", cas)
// Check that the mean matches the samples.
r, _ := x.Dims()
col := make([]float64, r)
meanEst := make([]float64, len(mean))
for i := range meanEst {
meanEst[i] = stat.Mean(mat.Col(col, i, x), nil)
if !floats.EqualApprox(mean, meanEst, tol) {
t.Errorf("Returned mean and sample mean mismatch. Case %v. Empirical %v, returned %v", cas, meanEst, mean)
type Cover interface {
CovarianceMatrix(*mat.SymDense) *mat.SymDense
func checkCov(t *testing.T, cas int, x *mat.Dense, c Cover, tol float64) {
cov := c.CovarianceMatrix(nil)
n := cov.Symmetric()
cov2 := mat.NewSymDense(n, nil)
if !mat.Equal(cov, cov2) {
t.Errorf("Cov mismatch when providing nil and matrix. Case %v", cas)
var cov3 mat.SymDense
if !mat.Equal(cov, &cov3) {
t.Errorf("Cov mismatch when providing zero matrix. Case %v", cas)
// Check that the covariance matrix matches the samples
covEst := stat.CovarianceMatrix(nil, x, nil)
if !mat.EqualApprox(covEst, cov, tol) {
t.Errorf("Return cov and sample cov mismatch. Cas %v.\nGot:\n%0.4v\nWant:\n%0.4v", cas, mat.Formatted(cov), mat.Formatted(covEst))