blob: 43bcb6878fc1308b7bbc21fe07016e32d1188cfc [file] [log] [blame]
// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distuv
import (
"fmt"
"math"
"testing"
"gonum.org/v1/gonum/diff/fd"
"gonum.org/v1/gonum/floats"
)
// univariateProbPoint is a single reference point used by
// testDistributionProbs: a location together with the values a
// distribution is expected to produce there.
type univariateProbPoint struct {
	// loc is the location handed to LogProb, Prob, CDF and Survival.
	loc float64
	// logProb, cumProb and prob are the expected results of LogProb, CDF
	// and Prob at loc.
	logProb, cumProb, prob float64
}
// UniProbDist is the set of methods that testDistributionProbs exercises
// on a univariate distribution.
type UniProbDist interface {
	// Prob, LogProb, CDF and Survival are each evaluated at a location x.
	Prob(x float64) float64
	CDF(x float64) float64
	LogProb(x float64) float64
	// Quantile is evaluated at a cumulative probability p.
	Quantile(p float64) float64
	Survival(x float64) float64
}
// absEq reports whether a and b agree to within an absolute tolerance of
// 1e-14. The comparison itself is delegated to absEqTol.
func absEq(a, b float64) bool {
	const tol = 1e-14
	return absEqTol(a, b, tol)
}
// absEqTol reports whether the absolute difference between a and b does
// not exceed tol. It returns false whenever a or b is NaN.
func absEqTol(a, b, tol float64) bool {
	switch {
	case math.IsNaN(a), math.IsNaN(b):
		// NaN never compares equal to anything.
		return false
	}
	diff := math.Abs(a - b)
	// diff is NaN when a and b are infinities of the same sign. Rejecting
	// only when diff is strictly greater than tol (instead of accepting
	// when it is less than or equal) lets that case compare as equal.
	if diff > tol {
		return false
	}
	return true
}
// testDistributionProbs checks dist against the reference points in pts.
//
// For every point it compares LogProb, Prob and CDF at pt.loc with the
// expected pt.logProb, pt.prob and pt.cumProb, and Survival at pt.loc with
// 1-pt.cumProb, all using absEq. When pt.prob is non-zero it additionally
// checks that Quantile(pt.cumProb) is within 1e-4 of pt.loc.
//
// name identifies the distribution in failure messages. Every mismatch is
// reported through t.Errorf, so a single call may report several failures.
//
// TODO: Implement a better test for Quantile
func testDistributionProbs(t *testing.T, dist UniProbDist, name string, pts []univariateProbPoint) {
	for _, pt := range pts {
		// Built once per point so that every failure names both the
		// distribution and the location, and so that name is passed as
		// data rather than being spliced into a format string.
		where := fmt.Sprintf("%s at %v", name, pt.loc)

		if got := dist.LogProb(pt.loc); !absEq(got, pt.logProb) {
			t.Errorf("Log probability doesn't match for %s. Expected %v. Found %v", where, pt.logProb, got)
		}
		if got := dist.Prob(pt.loc); !absEq(got, pt.prob) {
			t.Errorf("Probability doesn't match for %s. Expected %v. Found %v", where, pt.prob, got)
		}
		if got := dist.CDF(pt.loc); !absEq(got, pt.cumProb) {
			t.Errorf("Cumulative Probability doesn't match for %s. Expected %v. Found %v", where, pt.cumProb, got)
		}
		wantSurvival := 1 - pt.cumProb
		if got := dist.Survival(pt.loc); !absEq(got, wantSurvival) {
			t.Errorf("Survival doesn't match for %s. Expected %v. Found %v", where, wantSurvival, got)
		}

		// The quantile round trip is only attempted where the density is
		// non-zero.
		if pt.prob == 0 {
			continue
		}
		if got := dist.Quantile(pt.cumProb); math.Abs(got-pt.loc) > 1e-4 {
			t.Errorf("Quantile doesn't match for %s. Expected %v. Found %v", where, pt.loc, got)
		}
	}
}
// ConjugateUpdater is the set of methods that testConjugateUpdate and its
// panic-checking helpers require of a distribution.
type ConjugateUpdater interface {
	// NumParameters is used to size the prior strength slice handed to
	// ConjugateUpdate.
	NumParameters() int
	// parameters is called with nil to obtain the current parameters for
	// comparison.
	parameters(p []Parameter) []Parameter

	// NumSuffStat is used to size the sufficient statistic slice handed to
	// SuffStat and ConjugateUpdate.
	NumSuffStat() int
	// SuffStat is called with a statistic slice, samples and weights (which
	// may be nil); its result is forwarded to ConjugateUpdate as nSamples.
	SuffStat(suffStat, samples, weights []float64) float64
	ConjugateUpdate(suffStat []float64, nSamples float64, priorStrength []float64)

	// Rand supplies the random samples drawn by randn.
	Rand() float64
}
// testConjugateUpdate checks that conjugate updating gives the same
// parameters whether samples are folded in one at a time or all at once.
//
// newFittable must return a fresh distribution on every call. Three sample
// sets of ten draws are used: with nil weights, with unit weights, and with
// weights drawn from Exponential{Rate: 1}. After each additional sample:
//
//   - the incrementally updated distribution is compared (tolerance 1e-12)
//     with a fresh distribution updated once with all samples so far;
//   - when the weights are nil, a fresh distribution updated with explicit
//     unit weights is compared (tolerance 1e-14) with both of the above.
//
// Finally the panic behavior of SuffStat and ConjugateUpdate is checked via
// testSuffStatPanics and testConjugateUpdatePanics.
func testConjugateUpdate(t *testing.T, newFittable func() ConjugateUpdater) {
	for i, test := range []struct {
		samps   []float64
		weights []float64
	}{
		{
			samps:   randn(newFittable(), 10),
			weights: nil,
		},
		{
			samps:   randn(newFittable(), 10),
			weights: ones(10),
		},
		{
			samps:   randn(newFittable(), 10),
			weights: randn(&Exponential{Rate: 1}, 10),
		},
	} {
		// ensure that conjugate produces the same result both incrementally and all at once
		incDist := newFittable()
		stats := make([]float64, incDist.NumSuffStat())
		// prior is handed to every incremental update of incDist.
		// NOTE(review): presumably ConjugateUpdate updates it in place so
		// that it carries over between iterations — confirm against the
		// implementations.
		prior := make([]float64, incDist.NumParameters())
		for j := range test.samps {
			var incWeights, allWeights []float64
			if test.weights != nil {
				incWeights = test.weights[j : j+1]
				allWeights = test.weights[0 : j+1]
			}

			// Incremental: fold in only sample j.
			nsInc := incDist.SuffStat(stats, test.samps[j:j+1], incWeights)
			incDist.ConjugateUpdate(stats, nsInc, prior)

			// All at once: a fresh distribution sees samples 0..j together.
			allDist := newFittable()
			nsAll := allDist.SuffStat(stats, test.samps[0:j+1], allWeights)
			allDist.ConjugateUpdate(stats, nsAll, make([]float64, allDist.NumParameters()))
			if !parametersEqual(incDist.parameters(nil), allDist.parameters(nil), 1e-12) {
				t.Errorf("prior doesn't match after incremental update for (%d, %d). Incremental is %v, all at once is %v", i, j, incDist, allDist)
			}

			if test.weights == nil {
				// nil weights must behave like explicit unit weights.
				onesDist := newFittable()
				nsOnes := onesDist.SuffStat(stats, test.samps[0:j+1], ones(j+1))
				onesDist.ConjugateUpdate(stats, nsOnes, make([]float64, onesDist.NumParameters()))
				if !parametersEqual(onesDist.parameters(nil), incDist.parameters(nil), 1e-14) {
					t.Errorf("nil and uniform weighted prior doesn't match for incremental update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist)
				}
				if !parametersEqual(onesDist.parameters(nil), allDist.parameters(nil), 1e-14) {
					// Report allDist here: it is the distribution that was
					// actually compared against onesDist.
					t.Errorf("nil and uniform weighted prior doesn't match for all at once update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, allDist)
				}
			}
		}
	}
	testSuffStatPanics(t, newFittable)
	testConjugateUpdatePanics(t, newFittable)
}
// testSuffStatPanics checks that SuffStat panics when the weights slice has
// a different length from the samples, and when the sufficient statistic
// slice is one element longer than NumSuffStat.
func testSuffStatPanics(t *testing.T, newFittable func() ConjugateUpdater) {
	dist := newFittable()
	sample := randn(dist, 10)

	// One more weight than there are samples.
	mismatchedWeights := func() {
		stats := make([]float64, dist.NumSuffStat())
		weights := make([]float64, len(sample)+1)
		dist.SuffStat(stats, sample, weights)
	}
	if !panics(mismatchedWeights) {
		t.Errorf("Expected panic for mismatch between samples and weights lengths")
	}

	// Statistic slice one element too long.
	oversizedStats := func() {
		stats := make([]float64, dist.NumSuffStat()+1)
		dist.SuffStat(stats, sample, nil)
	}
	if !panics(oversizedStats) {
		t.Errorf("Expected panic for wrong sufficient statistic length")
	}
}
// testConjugateUpdatePanics checks that ConjugateUpdate panics when either
// the sufficient statistic slice or the prior strength slice is one element
// longer than the distribution reports it should be.
func testConjugateUpdatePanics(t *testing.T, newFittable func() ConjugateUpdater) {
	dist := newFittable()

	// Statistic slice one element too long, prior strength correctly sized.
	oversizedStats := func() {
		stats := make([]float64, dist.NumSuffStat()+1)
		prior := make([]float64, dist.NumParameters())
		dist.ConjugateUpdate(stats, 100, prior)
	}
	if !panics(oversizedStats) {
		t.Errorf("Expected panic for wrong sufficient statistic length")
	}

	// Statistic slice correctly sized, prior strength one element too long.
	oversizedPrior := func() {
		stats := make([]float64, dist.NumSuffStat())
		prior := make([]float64, dist.NumParameters()+1)
		dist.ConjugateUpdate(stats, 100, prior)
	}
	if !panics(oversizedPrior) {
		t.Errorf("Expected panic for wrong prior strength length")
	}
}
// randn returns a new slice holding n samples obtained by calling
// dist.Rand n times, stored in the order they were drawn.
func randn(dist Rander, n int) []float64 {
	samples := make([]float64, n)
	for i := 0; i < n; i++ {
		samples[i] = dist.Rand()
	}
	return samples
}
// ones returns a new slice of length n with every element set to 1.
func ones(n int) []float64 {
	out := make([]float64, n)
	for i := 0; i < n; i++ {
		out[i] = 1
	}
	return out
}
// parametersEqual reports whether p1 and p2 describe the same parameters:
// they must have the same length, and corresponding entries must have
// identical names and values that differ by no more than tol.
func parametersEqual(p1, p2 []Parameter, tol float64) bool {
	// Without this check a shorter p2 would cause an index-out-of-range
	// panic, and a longer p2 would have its extra entries ignored.
	if len(p1) != len(p2) {
		return false
	}
	for i, p := range p1 {
		q := p2[i]
		if p.Name != q.Name {
			return false
		}
		if math.Abs(p.Value-q.Value) > tol {
			return false
		}
	}
	return true
}
// derivParamTester is the set of methods that testDerivParam requires of a
// distribution in order to check its score functions numerically.
type derivParamTester interface {
	// LogProb is the function whose derivatives are approximated by finite
	// differences; Quantile supplies the locations at which it is checked.
	LogProb(x float64) float64
	Quantile(p float64) float64

	// Score is compared with the finite-difference gradient of LogProb with
	// respect to the parameters. testDerivParam calls it both with a
	// NumParameters-length slice and with nil.
	Score(deriv []float64, x float64) []float64
	// ScoreInput is compared with the finite-difference derivative of
	// LogProb with respect to x.
	ScoreInput(x float64) float64

	// NumParameters sizes the slices used for Score and for the parameter
	// values.
	NumParameters() int
	// parameters is called with nil to read the current parameters;
	// setParameters writes them back.
	parameters(p []Parameter) []Parameter
	setParameters(p []Parameter)
}
// testDerivParam checks the score functions of d against finite-difference
// derivatives of LogProb.
//
// It first verifies that Score, parameters and setParameters panic when
// given a slice one element too long, and that setParameters panics when
// any single parameter name is wrong. Then, at ten locations given by the
// quantiles evenly spaced from 0.1 to 0.9, it checks that:
//
//   - Score stores its result in the slice it was given and returns it;
//   - Score matches the gradient of LogProb with respect to the parameters
//     (tolerance 1e-6);
//   - Score called with nil matches the in-place result (tolerance 1e-14);
//   - ScoreInput matches the derivative of LogProb with respect to x
//     (tolerance 1e-6).
//
// The parameters of d are changed while the gradient is being evaluated and
// are reset to their initial values before each use that depends on them.
func testDerivParam(t *testing.T, d derivParamTester) {
	// Tests that the derivative matches for a number of different quantiles
	// along the distribution.
	nTest := 10
	quantiles := make([]float64, nTest)
	floats.Span(quantiles, 0.1, 0.9)

	scoreInPlace := make([]float64, d.NumParameters())
	fdDerivParam := make([]float64, d.NumParameters())

	if !panics(func() { d.Score(make([]float64, d.NumParameters()+1), 0) }) {
		t.Errorf("Expected panic for wrong derivative slice length")
	}
	if !panics(func() { d.parameters(make([]Parameter, d.NumParameters()+1)) }) {
		t.Errorf("Expected panic for wrong parameter slice length")
	}

	initParams := d.parameters(nil)
	tooLongParams := make([]Parameter, len(initParams)+1)
	copy(tooLongParams, initParams)
	if !panics(func() { d.setParameters(tooLongParams) }) {
		t.Errorf("Expected panic for wrong parameter slice length")
	}

	// Corrupt one name at a time, restoring it afterwards, so that each
	// parameter name is checked individually.
	badNameParams := make([]Parameter, len(initParams))
	copy(badNameParams, initParams)
	const badName = "__badName__"
	for i := 0; i < len(initParams); i++ {
		badNameParams[i].Name = badName
		if !panics(func() { d.setParameters(badNameParams) }) {
			t.Errorf("Expected panic for wrong %d-th parameter name", i)
		}
		badNameParams[i].Name = initParams[i].Name
	}

	// init holds the initial parameter values; it is the point at which the
	// finite-difference gradient is evaluated.
	init := make([]float64, d.NumParameters())
	for i, v := range initParams {
		init[i] = v.Value
	}
	for _, v := range quantiles {
		d.setParameters(initParams)
		x := d.Quantile(v)
		score := d.Score(scoreInPlace, x)
		// Comparing element addresses verifies that the returned slice
		// shares storage with the one passed in.
		if &score[0] != &scoreInPlace[0] {
			t.Errorf("Returned a different derivative slice than passed in. Got %v, want %v", score, scoreInPlace)
		}
		// logProbParams evaluates LogProb(x) as a function of the parameter
		// values. It leaves d with the parameters it was last called with.
		logProbParams := func(p []float64) float64 {
			params := d.parameters(nil)
			for i, v := range p {
				params[i].Value = v
			}
			d.setParameters(params)
			return d.LogProb(x)
		}
		fd.Gradient(fdDerivParam, logProbParams, init, nil)
		if !floats.EqualApprox(scoreInPlace, fdDerivParam, 1e-6) {
			t.Errorf("Score mismatch at x = %g. Want %v, got %v", x, fdDerivParam, scoreInPlace)
		}
		// Undo the parameter changes made while evaluating the gradient.
		d.setParameters(initParams)
		score2 := d.Score(nil, x)
		if !floats.EqualApprox(score2, scoreInPlace, 1e-14) {
			// scoreInPlace is the reference; score2 is the value under test.
			t.Errorf("Score mismatch when input is nil. Want %v, got %v", scoreInPlace, score2)
		}
		logProbInput := func(x2 float64) float64 {
			return d.LogProb(x2)
		}
		scoreInput := d.ScoreInput(x)
		fdDerivInput := fd.Derivative(logProbInput, x, nil)
		if !absEqTol(scoreInput, fdDerivInput, 1e-6) {
			t.Errorf("ScoreInput mismatch at x = %g. Want %v, got %v", x, fdDerivInput, scoreInput)
		}
	}
}