// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package privacy

import (
	"config"
	"fmt"
	"math"

	"gonum.org/v1/gonum/integrate/quad"
)

// ErrorCalculator estimates the error of privatized reports, using ParamsCalc
// to look up the privacy encoding parameters for each estimate.
type ErrorCalculator struct {
	ParamsCalc PrivacyEncodingParamsCalculator
}

// NewErrorCalculator returns an ErrorCalculator backed by |paramsCalc|.
func NewErrorCalculator(paramsCalc PrivacyEncodingParamsCalculator) *ErrorCalculator {
	return &ErrorCalculator{ParamsCalc: paramsCalc}
}

// NewErrorCalculatorFromPrivacyParams returns an ErrorCalculator whose
// PrivacyEncodingParamsCalculator is loaded from the file at
// |privacyParamsPath|. Any error from loading the parameters is returned
// unchanged, with a nil calculator.
func NewErrorCalculatorFromPrivacyParams(privacyParamsPath string) (*ErrorCalculator, error) {
	paramsCalculator, err := NewPrivacyEncodingParamsCalculator(privacyParamsPath)
	if err != nil {
		return nil, err
	}
	// NewErrorCalculator cannot fail, so there is no further error to check.
	return NewErrorCalculator(*paramsCalculator), nil
}

// Estimate returns the estimated per-row RMSE of |report|, a report on
// |metric|. The privacy encoding parameters are chosen from |epsilon| and the
// report's sparsity. |population| is the number of contributing devices.
// |minDenominatorEstimate| is a user-supplied lower bound on the unnoised
// denominator. It must be non-zero for FLEETWIDE_MEANS reports.
//
// A non-nil error is returned when:
//   - the sparsity or encoding parameters cannot be determined,
//   - the report type is unsupported,
//   - the report is misconfigured, or
//   - the computed estimate is NaN or infinite.
//
// In the last case the invalid estimate is returned together with the error.
// In all other error cases the estimate is -1.
func (e *ErrorCalculator) Estimate(metric *config.MetricDefinition, report *config.ReportDefinition, epsilon float64, population uint64, minDenominatorEstimate uint64) (estimate float64, err error) {
	sparsity, err := getSparsityForReport(metric, report)
	if err != nil {
		return -1, err
	}

	// Encoding parameters are looked up with the calculator's configured
	// population constant, not with |population|.
	params, err := e.ParamsCalc.GetPrivacyEncodingParams(epsilon, e.ParamsCalc.Constants.population, sparsity)
	if err != nil {
		return -1, err
	}
	p := params.ProbBitFlip
	indexPoints := uint64(params.NumIndexPoints)

	// Report types with an hourly contribution interval model the fleet as
	// (24*population) pseudo-users.
	hourlyPopulation := 24 * population

	var rmse float64
	switch reportType := report.GetReportType(); reportType {
	case config.ReportDefinition_UNIQUE_DEVICE_HISTOGRAMS, config.ReportDefinition_UNIQUE_DEVICE_COUNTS:
		// For histograms this is the RMSE of each individual bucket.
		rmse = SingleContributionRapporRMSE(population, p)
	case config.ReportDefinition_HOURLY_VALUE_HISTOGRAMS:
		rmse = SingleContributionRapporRMSE(hourlyPopulation, p)
	case config.ReportDefinition_FLEETWIDE_OCCURRENCE_COUNTS:
		valueRange := uint64(report.MaxValue - report.MinValue)
		rmse = MultipleContributionRapporRMSE(hourlyPopulation, p, valueRange, indexPoints)
	case config.ReportDefinition_FLEETWIDE_HISTOGRAMS:
		rmse = MultipleContributionRapporRMSE(hourlyPopulation, p, report.MaxCount, indexPoints)
	// TODO(jaredweinstein): Enable NumericStats once the formula is resolved.
	// case config.ReportDefinition_UNIQUE_DEVICE_NUMERIC_STATS:
	//         if err := meanReportConfigurationError(report, minDenominatorEstimate); err != nil {
	//                 return -1, err
	//         }
	//         errorEstimate = NumericStatsRapporRMSE(population, privacyEncodingParams, minDenominatorEstimate, report)
	// case config.ReportDefinition_HOURLY_VALUE_NUMERIC_STATS:
	//         if err := meanReportConfigurationError(report, minDenominatorEstimate); err != nil {
	//                 return -1, err
	//         }
	//         errorEstimate = NumericStatsRapporRMSE(24*population, privacyEncodingParams, minDenominatorEstimate, report)
	case config.ReportDefinition_FLEETWIDE_MEANS:
		if cfgErr := meanReportConfigurationError(report, minDenominatorEstimate); cfgErr != nil {
			return -1, cfgErr
		}
		rmse = FleetwideMeansRapporRMSE(hourlyPopulation, p, minDenominatorEstimate, indexPoints, report)
	default:
		typeName := config.ReportDefinition_ReportType_name[int32(reportType)]
		return -1, fmt.Errorf("Error estimation is not supported for reports of type %s.", typeName)
	}

	// Degenerate inputs (e.g. zero discretization or a bit-flip probability
	// of 0.5) surface here as NaN or Inf rather than as explicit errors.
	if math.IsNaN(rmse) || math.IsInf(rmse, 0) {
		return rmse, fmt.Errorf("Error estimation failed to return valid result due to an invalid or missing field.")
	}
	return rmse, nil
}

// SingleContributionRapporRMSE returns the 2D-RAPPOR RMSE for reports with
// single contributions from |population| users, where each bit is flipped
// with probability |probBitFlip|:
//
//	sqrt(n * p * (1 - p)) / (1 - 2p)
//
// The result is +Inf when probBitFlip is 0.5 and population is non-zero.
//
// See Proposition 1 and 2 of go/histogram-aggregation-privacy-guarantee.
func SingleContributionRapporRMSE(population uint64, probBitFlip float64) float64 {
	n := float64(population)
	stddev := math.Sqrt(n * probBitFlip * (1 - probBitFlip))
	return stddev / (1 - 2*probBitFlip)
}

// MultipleContributionRapporRMSE returns the 2D-RAPPOR RMSE for reports with
// multiple or real-number contributions. Let n = population, p = probBitFlip,
// m = contributionRange and r = discretization. The result is
//
//	m/r * sqrt(n*p*(1-p)/(1-2p)^2 * (2r^3 + 3r^2 + r)/6 + n/4)
//
// Note that (2r^3 + 3r^2 + r)/6 equals the sum of i^2 for i = 1..r.
// A zero discretization produces +Inf, or NaN when contributionRange is also 0.
//
// See Proposition 1, 2, and 3 of go/histogram-aggregation-privacy-guarantee.
func MultipleContributionRapporRMSE(population uint64, probBitFlip float64, contributionRange uint64, discretization uint64) float64 {
	n := float64(population)
	p := probBitFlip
	r := float64(discretization)

	perPointVariance := n * p * (1 - p) / math.Pow(1-2*p, 2)
	sixTimesSumOfSquares := 2*math.Pow(r, 3) + 3*math.Pow(r, 2) + r
	total := perPointVariance*sixTimesSumOfSquares/6 + n/4

	return float64(contributionRange) / r * math.Sqrt(total)
}

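// TODO(jaredweinstein): Add back NumericStats once the formula is resolved.
// NOTE: The draft below predates the current signatures of
// SingleContributionRapporRMSE, MultipleContributionRapporRMSE, and
// meanRapporRMSE. It must be updated to match them before it is re-enabled.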
// func NumericStatsRapporRMSE(population uint64, params PrivacyEncodingParams, minDenominatorEstimate uint64, report *config.ReportDefinition) float64 {
//         sigma := SingleContributionRapporRMSE(population, params)
//         a := params.ProbBitFlip / (1 - 2*params.ProbBitFlip)
//         M := uint64(math.Max(math.Abs(float64(report.MinValue)), math.Abs(float64(report.MaxValue))))
//
//         // Numerator is aggregated like FleetwideOccurrenceCounts.
//         rmseNumerator := MultipleContributionRapporRMSE(population, params, M)
//         mseNumerator := math.Pow(rmseNumerator, 2)
//
//         // Denominator is aggregated like UniqueDeviceCounts.
//         rmseDenominator := SingleContributionRapporRMSE(population, params)
//         mseDenominator := math.Pow(rmseDenominator, 2)
//
//         return meanRapporRMSE(mseNumerator, mseDenominator, sigma, a, float64(M), float64(minDenominatorEstimate))
// }

// FleetwideMeansRapporRMSE estimates the RMSE of a FLEETWIDE_MEANS report
// row, which is formed from a numerator and a denominator estimate.
//
// The numerator is modeled like a FLEETWIDE_OCCURRENCE_COUNTS aggregate with
// contribution range (report.MaxValue - report.MinValue). The denominator is
// modeled like a FLEETWIDE_HISTOGRAMS aggregate with contribution range
// report.MaxCount. |minDenominatorEstimate| is the lower bound on the unnoised
// denominator passed on to meanRapporRMSE.
func FleetwideMeansRapporRMSE(population uint64, probBitFlip float64, minDenominatorEstimate uint64, discretization uint64, report *config.ReportDefinition) float64 {
	maxCount := report.MaxCount
	a := probBitFlip * float64(maxCount) / (1 - 2*probBitFlip)

	// Equals max(|MinValue|, |MaxValue|) whenever MinValue <= MaxValue.
	valueBound := math.Max(float64(-report.MinValue), float64(report.MaxValue))

	valueRange := uint64(report.MaxValue - report.MinValue)
	numeratorRMSE := MultipleContributionRapporRMSE(population, probBitFlip, valueRange, discretization)
	denominatorRMSE := MultipleContributionRapporRMSE(population, probBitFlip, maxCount, discretization)

	return meanRapporRMSE(math.Pow(numeratorRMSE, 2), denominatorRMSE, a, valueBound, float64(minDenominatorEstimate))
}

// meanRapporRMSE combines the error terms of a mean estimate into an RMSE:
//
//	sqrt(mseNum*E1 + T^2*E2 + (T/B)^2 * sigma^2)
//
// E1 and E2 are the Lemma 10 terms computed by eOne and eTwo, with l fixed to 1.
// mseNum is the numerator's MSE, sigma the denominator's RMSE, T a bound on
// the magnitude of a value, and B a lower bound on the unnoised denominator.
func meanRapporRMSE(mseNum float64, sigma float64, a float64, T float64, B float64) float64 {
	const l = 1
	numeratorTerm := mseNum * eOne(B, sigma, a, l)
	boundTerm := math.Pow(T, 2) * eTwo(B, sigma, a, l)
	denominatorTerm := math.Pow(T/B, 2) * math.Pow(sigma, 2)
	return math.Sqrt(numeratorTerm + boundTerm + denominatorTerm)
}

// eOne computes the E1 term defined in Lemma 10:
//
//	1/B^2 + integral over y in [l-B, 0] of
//	        2/(B+y)^3 * exp(-(sigma^2/a^2) * h(-a*y/sigma^2)) dy
//
// The integral is approximated with quad.Fixed using 10000 points and a nil
// (default) rule.
func eOne(B float64, sigma float64, a float64, l float64) float64 {
	// These do not depend on y, so compute them once instead of per sample.
	sigmaSq := math.Pow(sigma, 2)
	aSq := math.Pow(a, 2)

	integrand := func(y float64) float64 {
		decay := math.Exp(-sigmaSq / aSq * h(-a*y/sigmaSq))
		return (2 / math.Pow(B+y, 3)) * decay
	}
	return 1/math.Pow(B, 2) + quad.Fixed(integrand, l-B, 0, 10000, nil, 0)
}

// eTwo computes the E2 term defined in Lemma 10:
//
//	-(integral over y in [l-B, 0] of
//	  2*B*y/(B+y)^3 * exp(-(sigma^2/a^2) * h(-a*y/sigma^2)) dy)
//
// The integral is approximated with quad.Fixed using 10000 points and a nil
// (default) rule.
func eTwo(B float64, sigma float64, a float64, l float64) float64 {
	// These do not depend on y, so compute them once instead of per sample.
	sigmaSq := math.Pow(sigma, 2)
	aSq := math.Pow(a, 2)

	integrand := func(y float64) float64 {
		weight := 2 * B * y / math.Pow(B+y, 3)
		return weight * math.Exp(-sigmaSq/aSq*h(-a*y/sigmaSq))
	}
	return -quad.Fixed(integrand, l-B, 0, 10000, nil, 0)
}

// h returns (1+u)*ln(1+u) - u, the function used in the exponent of the
// Lemma 10 integrands. It is NaN for u <= -1 (ln(0) is -Inf and 0*-Inf is
// NaN; ln of a negative number is NaN).
func h(u float64) float64 {
	shifted := 1 + u
	return shifted*math.Log(shifted) - u
}

// meanReportConfigurationError checks that |report| carries enough
// configuration to estimate the error of a mean report. It returns an error
// if |minDenominatorEstimate| is zero, or if both MinValue and MaxValue are
// zero. Otherwise it returns nil.
func meanReportConfigurationError(report *config.ReportDefinition, minDenominatorEstimate uint64) error {
	switch {
	case minDenominatorEstimate == 0:
		return fmt.Errorf("user estimate for lower bound on unnoised denominator required for %s", report.GetReportType())
	case report.MaxValue == 0 && report.MinValue == 0:
		return fmt.Errorf("MinValue and MaxValue required to estimate error for %s", report.GetReportType())
	default:
		return nil
	}
}
