blob: fa8d0f653bf0f6750ea46db10d8470eb8972afe0 [file] [log] [blame]
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stat
import (
"math"
"sort"
)
// ROC returns paired false positive rate (FPR) and true positive rate
// (TPR) values corresponding to cutoff points on the receiver operator
// characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh.
//
// The input y and cutoffs must be sorted, and values in y must correspond
// to values in classes and weights. SortWeightedLabeled can be used to
// sort y together with classes and weights. ROC panics if y, classes and
// a non-nil weights differ in length, or if y or cutoffs are not sorted
// ascending.
//
// For a given cutoff value, observations corresponding to entries in y
// greater than the cutoff value are classified as true, while those
// less than or equal to the cutoff value are classified as false. These
// assigned class labels are compared with the true values in the classes
// slice and used to calculate the FPR and TPR.
//
// If weights is nil, all weights are treated as 1.
//
// If cutoffs is nil or empty, all possible cutoffs are calculated,
// resulting in fpr and tpr having length one greater than the number of
// unique values in y. In that case, if cutoffs is non-nil and has capacity
// of at least len(y)+1, its backing array is reused to hold thresh.
// Otherwise fpr and tpr will be returned with the same length as cutoffs,
// and the provided cutoffs are not modified. floats.Span can be used to
// generate equally spaced cutoffs.
//
// The results are ordered by decreasing cutoff. Let c be the ascending
// list of cutoffs in use: -Inf followed by the unique values of y when
// cutoffs are computed, or a copy of cutoffs otherwise. Then tpr[i] and
// fpr[i] are the rates obtained with cutoff c[len(c)-1-i]; thresh[0] is
// +Inf and thresh[i] is c[len(c)-i] for i > 0.
//
// When cutoffs are computed, thresh[i] is the smallest value in y above
// that cutoff, so tpr[i] and fpr[i] are the true and false positive rates
// for y >= thresh[i]. When cutoffs are provided, this holds only if no
// value in y lies strictly between consecutive cutoffs. Note also that
// cutoffs[0] does not appear in thresh.
//
// If the total weight of positive (or negative) observations is zero,
// the TPR (or FPR) is undefined and NaN values may be returned.
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
	if len(y) != len(classes) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(y) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if !sort.Float64sAreSorted(y) {
		panic("stat: input must be sorted ascending")
	}
	if !sort.Float64sAreSorted(cutoffs) {
		panic("stat: cutoff values must be sorted ascending")
	}
	if len(y) == 0 {
		return nil, nil, nil
	}
	if len(cutoffs) == 0 {
		// Reuse the caller's backing array when it is large enough.
		if cutoffs == nil || cap(cutoffs) < len(y)+1 {
			cutoffs = make([]float64, len(y)+1)
		} else {
			cutoffs = cutoffs[:len(y)+1]
		}
		// The leading -Inf cutoff provides the end point of the curve
		// at which all finite observations are classified as true.
		cutoffs[0] = math.Inf(-1)
		// Choose all possible cutoffs for unique values in y.
		bin := 1
		cutoffs[bin] = y[0]
		for i, u := range y[1:] {
			if u == y[i] {
				continue
			}
			bin++
			cutoffs[bin] = u
		}
		cutoffs = cutoffs[:bin+1]
	} else {
		// Don't mutate the provided cutoffs.
		tmp := cutoffs
		cutoffs = make([]float64, len(cutoffs))
		copy(cutoffs, tmp)
	}

	// tpr and fpr first accumulate, per cutoff, the cumulative weight of
	// positive and negative observations with y <= cutoff, that is the
	// false negatives and true negatives at that cutoff.
	tpr = make([]float64, len(cutoffs))
	fpr = make([]float64, len(cutoffs))
	var bin int
	var nPos, nNeg float64
	for i, u := range classes {
		// Update the bin until it matches the next y value
		// skipping empty bins.
		for bin < len(cutoffs)-1 && y[i] > cutoffs[bin] {
			bin++
			tpr[bin] = tpr[bin-1]
			fpr[bin] = fpr[bin-1]
		}
		posWeight, negWeight := 1.0, 0.0
		if weights != nil {
			posWeight = weights[i]
		}
		if !u {
			posWeight, negWeight = negWeight, posWeight
		}
		nPos += posWeight
		nNeg += negWeight
		// Observations above the last cutoff are classified as true
		// for every cutoff and so are never counted here.
		if y[i] <= cutoffs[bin] {
			tpr[bin] += posWeight
			fpr[bin] += negWeight
		}
	}

	// Convert the negative counts to TPR and FPR. Bins after bin were
	// never reached: their cutoffs are at or above the largest value in
	// y, so no observation is classified as true there. Those elements
	// are left as an exact zero rather than converted from a zero count,
	// which would wrongly yield a rate of one.
	invNeg := 1 / nNeg
	invPos := 1 / nPos
	for i := range tpr[:bin+1] {
		// Prevent fused float operations by
		// making explicit float64 conversions.
		tpr[i] = 1 - float64(tpr[i]*invPos)
		fpr[i] = 1 - float64(fpr[i]*invNeg)
	}

	// Order the rates by decreasing cutoff.
	for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
		tpr[i], tpr[j] = tpr[j], tpr[i]
		fpr[i], fpr[j] = fpr[j], fpr[i]
	}
	// Reverse all but the first cutoff and replace that with +Inf so
	// that thresh[i] is the cutoff immediately above the one used for
	// tpr[i] and fpr[i].
	for i, j := 1, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
		cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
	}
	cutoffs[0] = math.Inf(1)
	return tpr, fpr, cutoffs
}