blob: 13a26b8e8009247720d9c79292cc52cf28d79de2 [file] [log] [blame]
 // Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package distuv import ( "math" "golang.org/x/exp/rand" ) // Categorical is an extension of the Bernoulli distribution where x takes // values {0, 1, ..., len(w)-1} where w is the weight vector. Categorical must // be initialized with NewCategorical. type Categorical struct { weights []float64 // heap is a weight heap. // // It keeps a heap-organised sum of remaining // index weights that are available to be taken // from. // // Each element holds the sum of weights for // the corresponding index, plus the sum of // its children's weights; the children of // an element i can be found at positions // 2*(i+1)-1 and 2*(i+1). The root of the // weight heap is at element 0. // // See comments in container/heap for an // explanation of the layout of a heap. heap []float64 src rand.Source } // NewCategorical constructs a new categorical distribution where the probability // that x equals i is proportional to w[i]. All of the weights must be // nonnegative, and at least one of the weights must be positive. func NewCategorical(w []float64, src rand.Source) Categorical { c := Categorical{ weights: make([]float64, len(w)), heap: make([]float64, len(w)), src: src, } c.ReweightAll(w) return c } // CDF computes the value of the cumulative density function at x. func (c Categorical) CDF(x float64) float64 { var cdf float64 for i, w := range c.weights { if x < float64(i) { break } cdf += w } return cdf / c.heap[0] } // Entropy returns the entropy of the distribution. func (c Categorical) Entropy() float64 { var ent float64 for _, w := range c.weights { if w == 0 { continue } p := w / c.heap[0] ent += p * math.Log(p) } return -ent } // Len returns the number of values x could possibly take (the length of the // initial supplied weight vector). 
func (c Categorical) Len() int {
	return len(c.weights)
}

// Mean returns the mean of the probability distribution, that is the
// weighted average of the indices.
func (c Categorical) Mean() float64 {
	var mean float64
	for i, v := range c.weights {
		mean += float64(i) * v
	}
	return mean / c.heap[0]
}

// Prob computes the value of the probability mass function at x. It returns
// 0 when x is not an integer in the range [0, Len()-1].
func (c Categorical) Prob(x float64) float64 {
	xi := int(x)
	if float64(xi) != x {
		// Non-integral x (including NaN) carries no probability mass.
		return 0
	}
	if xi < 0 || xi > len(c.weights)-1 {
		return 0
	}
	return c.weights[xi] / c.heap[0]
}

// LogProb computes the natural logarithm of the value of the probability
// mass function at x. It returns -Inf wherever Prob(x) is zero.
func (c Categorical) LogProb(x float64) float64 {
	return math.Log(c.Prob(x))
}

// Rand returns a random draw from the categorical distribution.
//
// When the distribution has no source, the package-level rand.Float64 is
// used; otherwise a generator wrapping the source is created for the draw.
func (c Categorical) Rand() float64 {
	var r float64
	if c.src == nil {
		r = c.heap[0] * rand.Float64()
	} else {
		r = c.heap[0] * rand.New(c.src).Float64()
	}
	// Walk down the weight heap using 1-based positions: position i
	// corresponds to item i-1 and has its children at 2i and 2i+1.
	i := 1
	last := -1
	left := len(c.weights)
	for {
		if r -= c.weights[i-1]; r <= 0 {
			break // Fall within item i-1.
		}
		i <<= 1 // Move to left child.
		if d := c.heap[i-1]; r > d {
			// r exceeds the total held by the left subtree, so skip
			// past it and continue in the right child; the break
			// above ends the walk once the item is reached.
			r -= d
			i++
		}
		if i == last || left < 0 {
			panic("categorical: bad sample")
		}
		last = i
		left--
	}
	return float64(i - 1)
}

// Reweight sets the weight of item idx to w. The input weight must be
// non-negative, and after reweighting at least one of the weights must be
// positive.
//
// Reweight panics if w is NaN or negative, in which case the distribution is
// left unmodified. It also panics if the sum of the weights is non-positive
// after the update; that check runs once the new weight has been stored.
func (c Categorical) Reweight(idx int, w float64) {
	// NaN compares false against every value, so w < 0 alone would let it
	// through and it would then poison every partial sum up to the root.
	if math.IsNaN(w) {
		panic("categorical: NaN weight")
	}
	if w < 0 {
		panic("categorical: negative weight")
	}
	delta := c.weights[idx] - w
	c.weights[idx] = w
	// Remove the difference from the item's own heap element and from every
	// ancestor, using 1-based positions so that halving finds the parent.
	for i := idx + 1; i > 0; i >>= 1 {
		c.heap[i-1] -= delta
	}
	if c.heap[0] <= 0 {
		panic("categorical: sum of the weights non-positive")
	}
}

// ReweightAll resets the weights of the distribution. ReweightAll panics if
// len(w) != c.Len. All of the weights must be nonnegative, and at least one of
// the weights must be positive.
func (c Categorical) ReweightAll(w []float64) { if len(w) != c.Len() { panic("categorical: length of the slices do not match") } for _, v := range w { if v < 0 { panic("categorical: negative weight") } } copy(c.weights, w) c.reset() } func (c Categorical) reset() { copy(c.heap, c.weights) for i := len(c.heap) - 1; i > 0; i-- { // Sometimes 1-based counting makes sense. c.heap[((i+1)>>1)-1] += c.heap[i] } // TODO(btracey): Renormalization for weird weights? if c.heap[0] <= 0 { panic("categorical: sum of the weights non-positive") } }