blob: 24f69de9f2ec9155afe597824061692cb94cc528 [file] [log] [blame]
 // Copyright ©2016 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package distmv import ( "math" "testing" "golang.org/x/exp/rand" "gonum.org/v1/gonum/mat" ) func TestDirichlet(t *testing.T) { // Data from Scipy. for cas, test := range []struct { Dir *Dirichlet x []float64 prob float64 }{ { NewDirichlet([]float64{1, 1, 1}, nil), []float64{0.2, 0.3, 0.5}, 2.0, }, { NewDirichlet([]float64{0.6, 10, 8.7}, nil), []float64{0.2, 0.3, 0.5}, 0.24079612737071665, }, } { p := test.Dir.Prob(test.x) if math.Abs(p-test.prob) > 1e-14 { t.Errorf("Probablility mismatch. Case %v. Got %v, want %v", cas, p, test.prob) } } rnd := rand.New(rand.NewSource(1)) for cas, test := range []struct { Dir *Dirichlet }{ { NewDirichlet([]float64{1, 1, 1}, rnd), }, { NewDirichlet([]float64{2, 3}, rnd), }, { NewDirichlet([]float64{0.2, 0.3}, rnd), }, { NewDirichlet([]float64{0.2, 4}, rnd), }, { NewDirichlet([]float64{0.1, 4, 20}, rnd), }, } { const n = 1e5 d := test.Dir dim := d.Dim() x := mat.NewDense(n, dim, nil) generateSamples(x, d) checkMean(t, cas, x, d, 1e-2) checkCov(t, cas, x, d, 1e-2) } }