blob: decd181dfc36029563f35b3aac3b8c5f66d02331 [file] [log] [blame]
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package distuv
import (
"math"
"math/rand"
"testing"
"gonum.org/v1/gonum/floats"
)
// TestCategoricalProb checks that Prob returns each category's normalized
// weight at integer inputs, and zero at non-integer inputs, at -1, and at
// len(weights).
func TestCategoricalProb(t *testing.T) {
	for _, test := range [][]float64{
		{1, 2, 3, 0},
	} {
		dist := NewCategorical(test, nil)
		// The expected probabilities are the weights scaled to sum to one.
		// The weights must be copied in before scaling: normalizing a fresh
		// (all-zero) slice divides by a zero sum and produces NaN, and a NaN
		// expectation makes the comparison below vacuously pass.
		norm := make([]float64, len(test))
		copy(norm, test)
		floats.Scale(1/floats.Sum(norm), norm)
		for i, v := range norm {
			p := dist.Prob(float64(i))
			// Reject NaN explicitly; math.Abs(NaN) > tol is always false.
			if math.IsNaN(p) || math.Abs(p-v) > 1e-14 {
				t.Errorf("Probability mismatch element %d", i)
			}
			p = dist.Prob(float64(i) + 0.5)
			if p != 0 {
				t.Errorf("Non-zero probability for non-integer x")
			}
		}
		// Values just outside the category range must have zero probability.
		p := dist.Prob(-1)
		if p != 0 {
			t.Errorf("Non-zero probability for -1")
		}
		p = dist.Prob(float64(len(test)))
		if p != 0 {
			t.Errorf("Non-zero probability for len(test)")
		}
	}
}
// TestCategoricalRand draws many samples from a Categorical and checks that
// the empirical frequencies match Prob. It does so three times: for the
// initial weights, after Reweight of the last category, and after
// ReweightAll with pseudo-random weights.
func TestCategoricalRand(t *testing.T) {
	const nSamples = 2000000
	for _, test := range [][]float64{
		{1, 2, 3, 0},
	} {
		dist := NewCategorical(test, nil)

		// check samples dist in its current state and compares the observed
		// frequencies with Prob. stage is inserted into the failure message
		// to identify which phase of the test failed.
		check := func(stage string) {
			counts := sampleCategorical(t, dist, nSamples)
			probs := make([]float64, len(test))
			for i := range probs {
				probs[i] = dist.Prob(float64(i))
			}
			if !samedDistCategorical(dist, counts, probs, 1e-2) {
				t.Errorf("Probability mismatch%s. Want %v, got %v", stage, probs, counts)
			}
		}

		check("")

		dist.Reweight(len(test)-1, 10)
		check(" after Reweight")

		w := make([]float64, len(test))
		for i := range w {
			w[i] = rand.Float64()
		}
		dist.ReweightAll(w)
		check(" after ReweightAll")
	}
}
// sampleCategorical draws nSamples values from dist and returns, for each
// category index, the fraction of samples that landed on it. The test is
// stopped with Fatalf if a sample is not an integer or falls outside
// [0, dist.Len()). If nSamples is not positive, all fractions are zero.
func sampleCategorical(t *testing.T, dist Categorical, nSamples int) []float64 {
	counts := make([]float64, dist.Len())
	for i := 0; i < nSamples; i++ {
		v := dist.Rand()
		if float64(int(v)) != v {
			t.Fatalf("Random number is not an integer")
		}
		idx := int(v)
		// Report an out-of-range sample as a test failure instead of letting
		// the indexing below panic.
		if idx < 0 || idx >= len(counts) {
			t.Fatalf("Random number %v out of range [0, %d)", v, len(counts))
		}
		counts[idx]++
	}
	// Only normalize when something was counted; dividing by a zero sum
	// would turn every entry into NaN.
	if sum := floats.Sum(counts); sum > 0 {
		floats.Scale(1/sum, counts)
	}
	return counts
}
// samedDistCategorical reports whether the observed frequencies in counts
// agree with the expected probabilities in probs, element by element, within
// tol (absolute or relative). A category whose expected probability is zero
// must have an observed frequency of exactly zero. The dist argument is
// currently unused.
func samedDistCategorical(dist Categorical, counts, probs []float64, tol float64) bool {
	for i := range probs {
		want, got := probs[i], counts[i]
		if want == 0 && got != 0 {
			return false
		}
		if !floats.EqualWithinAbsOrRel(want, got, tol, tol) {
			return false
		}
	}
	return true
}
// TestCategoricalCDF compares CDF with the running sum of the normalized
// weights. It checks three things: CDF is zero below the first category, it
// equals the running total at each integer category, and it stays constant
// between consecutive integers.
func TestCategoricalCDF(t *testing.T) {
	cases := [][]float64{
		{1, 2, 3, 0, 4},
	}
	for _, weights := range cases {
		n := len(weights)

		// Expected CDF: cumulative sum of the weights normalized to one.
		normalized := make([]float64, n)
		copy(normalized, weights)
		floats.Scale(1/floats.Sum(normalized), normalized)
		cumulative := make([]float64, n)
		floats.CumSum(cumulative, normalized)

		dist := NewCategorical(weights, nil)
		if below := dist.CDF(-0.5); below != 0 {
			t.Errorf("CDF of negative number not zero")
		}
		for i := 0; i < n; i++ {
			x := float64(i)
			atInt := dist.CDF(x)
			if math.Abs(atInt-cumulative[i]) > 1e-14 {
				t.Errorf("CDF mismatch %v. Want %v, got %v.", x, cumulative[i], atInt)
			}
			// Halfway to the next category the CDF must not have moved.
			if between := dist.CDF(x + 0.5); between != atInt {
				t.Errorf("CDF mismatch for non-integer input")
			}
		}
	}
}
// TestCategoricalEntropy compares Entropy with closed-form values for
// weightings that are uniform over their non-zero categories.
func TestCategoricalEntropy(t *testing.T) {
	tests := []struct {
		weights []float64
		entropy float64
	}{
		{weights: []float64{1, 1}, entropy: math.Ln2},
		{weights: []float64{1, 1, 1, 1}, entropy: math.Log(4)},
		// Expected value treats the zero-weight categories as contributing
		// nothing, i.e. the same as a two-category uniform distribution.
		{weights: []float64{0, 0, 1, 1, 0, 0}, entropy: math.Ln2},
	}
	for _, tc := range tests {
		dist := NewCategorical(tc.weights, nil)
		got := dist.Entropy()
		// NaN is checked explicitly because a NaN difference never exceeds
		// the tolerance.
		if math.IsNaN(got) || math.Abs(got-tc.entropy) > 1e-14 {
			t.Errorf("Entropy mismatch. Want %v, got %v.", tc.entropy, got)
		}
	}
}
// TestCategoricalMean compares Mean with hand-computed expectations. The
// cases cover all mass on category 0, all mass on category 1, and weights
// 1..4, whose normalized mean is 0*0.1 + 1*0.2 + 2*0.3 + 3*0.4 = 2.
func TestCategoricalMean(t *testing.T) {
	for _, test := range []struct {
		weights []float64
		mean    float64
	}{
		{
			weights: []float64{10, 0, 0, 0},
			mean:    0,
		},
		{
			weights: []float64{0, 10, 0, 0},
			mean:    1,
		},
		{
			weights: []float64{1, 2, 3, 4},
			mean:    2,
		},
	} {
		dist := NewCategorical(test.weights, nil)
		mean := dist.Mean()
		// NaN is checked explicitly because a NaN difference never exceeds
		// the tolerance.
		if math.IsNaN(mean) || math.Abs(mean-test.mean) > 1e-14 {
			t.Errorf("Mean mismatch. Want %v, got %v.", test.mean, mean)
		}
	}
}