blob: a59553b1beb353c2b161ac61d78fd34237e7d90f [file] [log] [blame]
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distuv
import (
"math"
"sort"
"testing"
"golang.org/x/exp/rand"
)
// TestBernoulli exercises the Bernoulli distribution for several success
// probabilities, including the degenerate values P = 0 and P = 1. Every case
// draws from the same random source, seeded with 1.
func TestBernoulli(t *testing.T) {
	t.Parallel()

	rnd := rand.New(rand.NewSource(1))
	cases := []Bernoulli{
		{P: 0.5, Src: rnd},
		{P: 0.9, Src: rnd},
		{P: 0.2, Src: rnd},
		{P: 0.0, Src: rnd},
		{P: 1.0, Src: rnd},
	}
	// Points other than 0 and 1, where the test requires zero probability.
	offSupport := []float64{-0.2, 0.5, 1.1}

	for idx, b := range cases {
		testBernoulli(t, b, idx)
		testBernoulliCDF(t, b)
		testBernoulliSurvival(t, b)
		testBernoulliQuantile(t, b)

		// The degenerate cases are required to report exactly zero entropy.
		if p := b.P; p == 0 || p == 1 {
			if h := b.Entropy(); h != 0 {
				t.Errorf("Entropy of a Bernoulli distribution with P = %g is not zero, got: %g", p, h)
			}
		}

		if got := b.NumParameters(); got != 1 {
			t.Errorf("Wrong number of parameters")
		}

		for _, x := range offSupport {
			logProb, prob := b.LogProb(x), b.Prob(x)
			if !math.IsInf(logProb, -1) {
				t.Errorf("Log-probability for x = %g is not -Inf, got: %g", x, logProb)
			}
			if prob != 0 {
				t.Errorf("Probability for x = %g is not 0, got: %g", x, prob)
			}
		}
	}
}
// testBernoulli fills a slice of 3e6 values via generateSamples, sorts it,
// and runs the shared check helpers on it with a tolerance of 1e-2. The case
// index i is forwarded to those helpers.
//
// For P equal to 0 or 1 the sample-based kurtosis and skewness checks are
// skipped; instead the reported excess kurtosis must be +Inf and the reported
// skewness must be +Inf (P == 0) or -Inf (P == 1). For P == 0.5 the
// sample-based median check is replaced by requiring Median() == 0.5.
func testBernoulli(t *testing.T, dist Bernoulli, i int) {
	const (
		tol = 1e-2
		n   = 3e6
	)

	x := make([]float64, n)
	generateSamples(x, dist)
	sort.Float64s(x)

	checkMean(t, i, x, dist, tol)
	checkVarAndStd(t, i, x, dist, tol)
	checkEntropy(t, i, x, dist, tol)
	checkProbDiscrete(t, i, x, dist, tol)

	if dist.P != 0 && dist.P != 1 {
		// Sample kurtosis and skewness are going to be NaN for P = 0 or 1.
		checkExKurtosis(t, i, x, dist, tol)
		checkSkewness(t, i, x, dist, tol)
	} else {
		if !math.IsInf(dist.ExKurtosis(), 1) {
			t.Errorf("Excess kurtosis for P == 0 or 1 is not +Inf")
		}
		// In this branch P is exactly 0 or 1, so exactly one of the two
		// skewness expectations below applies.
		skewness := dist.Skewness()
		if dist.P == 0 && !math.IsInf(skewness, 1) {
			t.Errorf("Skewness for P == 0 is not +Inf")
		}
		if dist.P == 1 && !math.IsInf(skewness, -1) {
			t.Errorf("Skewness for P == 1 is not -Inf")
		}
	}

	if dist.P != 0.5 {
		checkMedian(t, i, x, dist, tol)
	} else if dist.Median() != 0.5 {
		t.Errorf("Median for P == 0.5 is not 0.5")
	}
}
// testBernoulliCDF checks dist.CDF at probe points below, at, between and
// above the values 0 and 1. It expects 0 below zero, 1-P from zero up to (but
// not including) one, and 1 from one upwards. Comparisons are exact.
func testBernoulliCDF(t *testing.T, dist Bernoulli) {
	belowOne := 1 - dist.P

	checks := []struct {
		x, want float64
		msg     string
	}{
		{-0.000001, 0, "Bernoulli CDF below zero is not zero"},
		{0, belowOne, "Bernoulli CDF at zero is not 1 - P(1)"},
		{0.0001, belowOne, "Bernoulli CDF between zero and one is not 1 - P(1)"},
		{0.9999, belowOne, "Bernoulli CDF between zero and one is not 1 - P(1)"},
		{1, 1, "Bernoulli CDF at one is not one"},
		{1.00001, 1, "Bernoulli CDF above one is not one"},
	}

	for _, c := range checks {
		if dist.CDF(c.x) != c.want {
			t.Error(c.msg)
		}
	}
}
// testBernoulliSurvival checks dist.Survival at probe points below, at,
// between and above the values 0 and 1. It expects 1 below zero, P from zero
// up to (but not including) one, and 0 from one upwards. Comparisons are
// exact.
//
// The probe at 0.9999 matches the one in testBernoulliCDF, so the interval
// between the support points is checked near both of its ends.
func testBernoulliSurvival(t *testing.T, dist Bernoulli) {
	checks := []struct {
		x, want float64
		msg     string
	}{
		{-0.000001, 1, "Bernoulli Survival below zero is not one"},
		{0, dist.P, "Bernoulli Survival at zero is not P(1)"},
		{0.0001, dist.P, "Bernoulli Survival between zero and one is not P(1)"},
		{0.9999, dist.P, "Bernoulli Survival between zero and one is not P(1)"},
		{1, 0, "Bernoulli Survival at one is not zero"},
		{1.00001, 0, "Bernoulli Survival above one is not zero"},
	}

	for _, c := range checks {
		if dist.Survival(c.x) != c.want {
			t.Error(c.msg)
		}
	}
}
// testBernoulliQuantile checks dist.Quantile:
//   - it must panic for arguments just outside [0, 1];
//   - Quantile(CDF(x)) must return x for x in {0, 1}, except that 0 is
//     expected for both when P == 0;
//   - Quantile(1) must be 1, or 0 when P == 0;
//   - for P strictly between eps and 1-eps, Quantile must be 0 at and just
//     below 1-P, and 1 just above it.
func testBernoulliQuantile(t *testing.T, dist Bernoulli) {
	if !panics(func() { dist.Quantile(-0.0001) }) {
		t.Errorf("Expected panic with negative argument")
	}
	if !panics(func() { dist.Quantile(1.0001) }) {
		t.Errorf("Expected panic with argument above 1")
	}

	for _, x := range []float64{0, 1} {
		want := x
		if dist.P == 0 {
			// With P == 0 the round trip is expected to yield 0 for both
			// probe points.
			want = 0
		}
		if dist.Quantile(dist.CDF(x)) != want {
			t.Errorf("Quantile(CDF(x)) not equal to %g for x = %g for P = %g", want, x, dist.P)
		}
	}

	wantAtOne := 1.0
	if dist.P == 0 {
		wantAtOne = 0
	}
	if got := dist.Quantile(1); got != wantAtOne {
		// Report the real expectation: it is 0, not 1, when P == 0.
		t.Errorf("Quantile at 1 not equal to %g for P = %g, got: %g", wantAtOne, dist.P, got)
	}

	eps := 1e-12
	if dist.P > eps && dist.P < 1-eps {
		// step is the probability level at which the expected quantile
		// changes from 0 to 1.
		step := 1 - dist.P
		if dist.Quantile(step-eps) != 0 {
			t.Errorf("Quantile slightly below 0 < 1-P < 1 is not zero")
		}
		if dist.Quantile(step+eps) != 1 {
			t.Errorf("Quantile slightly above 0 < 1-P < 1 is not one")
		}
		if dist.Quantile(step) != 0 {
			t.Errorf("Quantile at 0 < 1-P < 1 is not zero")
		}
	}
}