// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package privacy

import (
	"config"
	"encoding/csv"
	"fmt"
	"os"
	"sort"
	"strconv"
)

// PrivacyEncodingParams holds the parameters which the PrivacyEncoder needs in order to encode
// observations for a report. The config parser will copy these values into the ReportDefinition
// for the corresponding report.
//
// ProbBitFlip and NumIndexPoints are populated from the prob_bit_flip and num_index_points columns
// of a parameter record (see parsePrivacyEncodingRecord).
type PrivacyEncodingParams struct {
	ProbBitFlip    float64
	NumIndexPoints uint32
}

// privacyConstants holds the constants needed in order to map the PrivacyLevel of a report to a
// PrivacyEncodingParams.
type privacyConstants struct {
	// The target shuffled epsilon corresponding to each Cobalt PrivacyLevel.
	// A PrivacyLevel with no entry in this map is reported as an error by
	// GetPrivacyEncodingParamsForReport.
	EpsilonForPrivacyLevel map[config.ReportDefinition_PrivacyLevel]float64
	// The estimated population size.
	population uint64
}

// paramsMapKey identifies one row of the precomputed parameter table: it is the key under which
// the PrivacyEncodingParams for a report are stored. The field values are derived from the
// ReportDefinition and a privacyConstants.
type paramsMapKey struct {
	epsilon              float64
	population, sparsity uint64
}

// paramsKeyLists collects, per key component, the set of values which may appear in a
// paramsMapKey. Every list is expected to be sorted in increasing order.
type paramsKeyLists struct {
	epsilons                []float64
	populations, sparsities []uint64
}

// PrivacyEncodingParamsCalculator maintains a lookup table |paramMap| mapping paramsMapKeys to
// PrivacyEncodingParams. Given a ReportDefinition, it finds the best-match key and returns the
// corresponding PrivacyEncodingParams.
type PrivacyEncodingParamsCalculator struct {
	// The constants (epsilon per PrivacyLevel, population) used to build a lookup key for a report.
	Constants privacyConstants
	// The lookup table.
	paramMap map[paramsMapKey]PrivacyEncodingParams
	// The epsilon values, population sizes, and sparsities which are mapped in |paramMap|.
	// Any element of the cross product of the lists in |mapped| should be a valid key of |paramMap|.
	mapped paramsKeyLists
}

// makePrivacyConstants builds a privacyConstants populated with hard-coded values: one target
// epsilon per PrivacyLevel and a fixed population estimate.
// TODO(pesk, azani): Decide how these values should be configured.
func makePrivacyConstants() (pc privacyConstants) {
	return privacyConstants{
		EpsilonForPrivacyLevel: map[config.ReportDefinition_PrivacyLevel]float64{
			config.ReportDefinition_LOW_PRIVACY:    10.0,
			config.ReportDefinition_MEDIUM_PRIVACY: 5.0,
			config.ReportDefinition_HIGH_PRIVACY:   1.0,
		},
		population: 10000,
	}
}

// NewPrivacyEncodingParamsCalculator is the public factory for creating a
// PrivacyEncodingParamsCalculator from records stored in a CSV file.
//
// The file at |paramPath| should contain records of form
// {epsilon, population, sparsity, prob_bit_flip, num_index_points}.
// A nil calculator and a non-nil error are returned if the file cannot be read or parsed.
func NewPrivacyEncodingParamsCalculator(paramPath string) (calc *PrivacyEncodingParamsCalculator, err error) {
	records, readErr := readFromCsvFile(paramPath)
	if readErr != nil {
		return nil, readErr
	}
	return newPrivacyEncodingParamsCalculatorFromRecords(records)
}

// NewPrivacyEncodingParamsCalculatorForTesting is an alternative public factory which lets unit
// tests create a PrivacyEncodingParamsCalculator directly from in-memory |paramRecords|.
func NewPrivacyEncodingParamsCalculatorForTesting(paramRecords [][]string) (calc *PrivacyEncodingParamsCalculator, err error) {
	calc, err = newPrivacyEncodingParamsCalculatorFromRecords(paramRecords)
	return calc, err
}

// newPrivacyEncodingParamsCalculatorFromRecords is the private factory which builds a
// PrivacyEncodingParamsCalculator from in-memory records.
// See NewPrivacyEncodingParamsCalculator() for the expected format of the records.
// Returns a nil calculator and the parse error if any record is malformed.
func newPrivacyEncodingParamsCalculatorFromRecords(paramRecords [][]string) (calc *PrivacyEncodingParamsCalculator, err error) {
	table, keyLists, err := mapPrivacyEncodingParams(paramRecords)
	if err != nil {
		return nil, err
	}

	calc = &PrivacyEncodingParamsCalculator{
		Constants: makePrivacyConstants(),
		paramMap:  table,
		mapped:    keyLists,
	}
	return calc, nil
}

// readFromCsvFile reads all records from the CSV file at |path|.
//
// Returns an error if |path| cannot be stat'ed, is not a regular file, cannot be opened, or
// does not contain well-formed CSV. The file is closed before returning.
func readFromCsvFile(path string) (records [][]string, err error) {
	info, err := os.Stat(path)
	if err != nil {
		return nil, err
	}

	if !info.Mode().IsRegular() {
		return nil, fmt.Errorf("%v is not a file.", path)
	}

	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// The file is only read from, so a Close error cannot lose data and is deliberately ignored.
	defer f.Close()

	return csv.NewReader(f).ReadAll()
}

// mapPrivacyEncodingParams parses |records| into a lookup table |m| from paramsMapKeys to
// PrivacyEncodingParams. It also returns a paramsKeyLists holding the distinct epsilons,
// population sizes, and sparsities which occur as key components, each sorted in increasing order.
//
// Parsing stops at the first malformed record; in that case the parse error is returned together
// with the entries mapped so far and an unpopulated |lists|.
func mapPrivacyEncodingParams(records [][]string) (m map[paramsMapKey]PrivacyEncodingParams, lists paramsKeyLists, err error) {
	m = make(map[paramsMapKey]PrivacyEncodingParams)

	// Sets of the distinct values seen for each key component.
	var (
		seenEpsilons    = make(map[float64]bool)
		seenPopulations = make(map[uint64]bool)
		seenSparsities  = make(map[uint64]bool)
	)

	for i := range records {
		k, p, parseErr := parsePrivacyEncodingRecord(records[i])
		if parseErr != nil {
			return m, lists, parseErr
		}

		m[k] = p
		seenEpsilons[k.epsilon] = true
		seenPopulations[k.population] = true
		seenSparsities[k.sparsity] = true
	}

	sortAndStoreKeysFloat64(seenEpsilons, &lists.epsilons)
	sortAndStoreKeysUint64(seenPopulations, &lists.populations)
	sortAndStoreKeysUint64(seenSparsities, &lists.sparsities)

	return m, lists, nil
}

// parsePrivacyEncodingRecord converts one record into a paramsMapKey and a PrivacyEncodingParams.
// The expected format of a record is:
// {epsilon, population, sparsity, prob_bit_flip, num_index_points}
//
// An error is returned if the record does not have exactly 5 fields or if any field fails to
// parse; |key| and |params| are left at their zero values in that case.
func parsePrivacyEncodingRecord(record []string) (key paramsMapKey, params PrivacyEncodingParams, err error) {
	const numFields = 5
	if n := len(record); n != numFields {
		return key, params, fmt.Errorf("wrong record size: %d", n)
	}

	// Parse into locals so that the results stay untouched unless every field is valid.
	var (
		eps, flipProb          float64
		pop, sparse, numPoints uint64
	)
	if eps, err = strconv.ParseFloat(record[0], 64); err != nil {
		return key, params, err
	}
	if pop, err = strconv.ParseUint(record[1], 10, 64); err != nil {
		return key, params, err
	}
	if sparse, err = strconv.ParseUint(record[2], 10, 64); err != nil {
		return key, params, err
	}
	if flipProb, err = strconv.ParseFloat(record[3], 64); err != nil {
		return key, params, err
	}
	// Parsed with a 32-bit size limit so that the conversion to uint32 below cannot truncate.
	if numPoints, err = strconv.ParseUint(record[4], 10, 32); err != nil {
		return key, params, err
	}

	key = paramsMapKey{epsilon: eps, population: pop, sparsity: sparse}
	params = PrivacyEncodingParams{ProbBitFlip: flipProb, NumIndexPoints: uint32(numPoints)}
	return key, params, nil
}

// sortAndStoreKeysFloat64 collects the keys of |m| (regardless of their mapped values), sorts
// them in increasing order, and overwrites |*vals| with the result. |*vals| is always set to a
// non-nil slice, even when |m| is empty.
func sortAndStoreKeysFloat64(m map[float64]bool, vals *[]float64) {
	keys := make([]float64, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Float64s(keys)
	*vals = keys
}

// sortAndStoreKeysUint64 collects the keys of |m| (regardless of their mapped values), sorts
// them in increasing order, and overwrites |*vals| with the result. |*vals| is always set to a
// non-nil slice, even when |m| is empty.
func sortAndStoreKeysUint64(m map[uint64]bool, vals *[]uint64) {
	keys := make([]uint64, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(a, b int) bool { return keys[a] < keys[b] })
	*vals = keys
}

// GetPrivacyEncodingParamsForReport looks up the corresponding PrivacyEncodingParams from |calc|'s
// paramMap, given a |metric| and |report|.
//
// The lookup key is built from the epsilon mapped to |report|'s PrivacyLevel in |calc.Constants|,
// the population in |calc.Constants|, and the sparsity computed from |metric| and |report|. The
// integer range size of |report| is passed along so that NumIndexPoints can be capped by it.
//
// If paramMap does not have a key which exactly matches the values drawn from |metric|, |report|,
// and |calc.Constants|, then parameters are returned for the closest key which provides at least as
// much privacy as is required by |report|'s PrivacyLevel. See getBestMappedKey for more details.
//
// An error is returned if |report|'s PrivacyLevel has no configured epsilon, if the sparsity or
// range size cannot be computed, or if the lookup itself fails.
func (calc *PrivacyEncodingParamsCalculator) GetPrivacyEncodingParamsForReport(metric *config.MetricDefinition, report *config.ReportDefinition) (params PrivacyEncodingParams, err error) {
	epsilon, ok := calc.Constants.EpsilonForPrivacyLevel[report.PrivacyLevel]
	if !ok {
		return params, fmt.Errorf("no epsilon found for privacy level: %v", report.PrivacyLevel)
	}

	sparsity, err := getSparsityForReport(metric, report)
	if err != nil {
		return params, err
	}

	rangeSize, err := GetIntegerRangeSizeForReport(report)
	if err != nil {
		return params, err
	}

	return calc.GetPrivacyEncodingParams(epsilon, calc.Constants.population, sparsity, rangeSize)
}

// GetPrivacyEncodingParams looks up the PrivacyEncodingParams in |calc|'s paramMap for a given
// |epsilon|, |population|, and |sparsity|.
//
// If |rangeSize| is smaller than the number of index points found in the table, |rangeSize| is
// used as the NumIndexPoints field of the returned PrivacyEncodingParams instead, in order to
// avoid incurring unnecessary rounding error.
//
// If paramMap does not have a key which exactly matches the tuple (|epsilon|, |population|, |sparsity|),
// then parameters are returned for the closest key which provides |epsilon|-differential privacy
// (in the shuffled model) or better. See getBestMappedKey for more details.
//
// An error is returned if no suitable key exists or if the chosen key is missing from paramMap.
func (calc *PrivacyEncodingParamsCalculator) GetPrivacyEncodingParams(epsilon float64, population uint64, sparsity uint64, rangeSize uint64) (params PrivacyEncodingParams, err error) {
	bestKey, err := getBestMappedKey(epsilon, population, sparsity, &calc.mapped)
	if err != nil {
		return params, err
	}

	found, ok := calc.paramMap[bestKey]
	if !ok {
		return found, fmt.Errorf("no params found for key: (epsilon=%f, population=%d, sparsity=%d)", bestKey.epsilon, bestKey.population, bestKey.sparsity)
	}

	// The conversion cannot truncate: it only happens when |rangeSize| is below a uint32 value.
	if uint64(found.NumIndexPoints) > rangeSize {
		found.NumIndexPoints = uint32(rangeSize)
	}
	return found, nil
}

// GetIntegerRangeSizeForReport returns the number of valid integer values for |report|. For
// FleetwideOccurrenceCounts, UniqueDeviceNumericStats, and HourlyValueNumericStats reports, this
// is the number of integers in the range [|report.MinValue|, |report.MaxValue|] including both
// endpoints. For FleetwideHistograms and StringCounts reports it is the number of integers in the
// range [0, |report.MaxCount|].
//
// For UniqueDeviceCounts, UniqueDeviceHistograms, and HourlyValueHistograms reports, all observations
// contain the same (implicit or explicit) numeric value of 1, so the returned range size is 1.
//
// A FleetwideMeans report has two separate configured ranges: one for sum values
// ([|report.MinValue|, |report.MaxValue|]) and another for count values ([0, |report.MaxCount|]).
// The returned range size is the maximum of the two range sizes.
//
// An error is returned if the configured value range is empty (|report.MinValue| is larger than
// |report.MaxValue|) or if the report type is not supported.
func GetIntegerRangeSizeForReport(report *config.ReportDefinition) (rangeSize uint64, err error) {
	switch report.ReportType {
	case config.ReportDefinition_FLEETWIDE_OCCURRENCE_COUNTS,
		config.ReportDefinition_UNIQUE_DEVICE_NUMERIC_STATS,
		config.ReportDefinition_HOURLY_VALUE_NUMERIC_STATS:
		size := report.MaxValue - report.MinValue + 1
		if size <= 0 {
			return rangeSize, fmt.Errorf("min value %d is larger than max value %d", report.MinValue, report.MaxValue)
		}
		rangeSize = uint64(size)
	case config.ReportDefinition_FLEETWIDE_HISTOGRAMS,
		config.ReportDefinition_STRING_COUNTS:
		rangeSize = report.MaxCount + 1
	case config.ReportDefinition_FLEETWIDE_MEANS:
		sumSize := report.MaxValue - report.MinValue + 1
		if sumSize <= 0 {
			return rangeSize, fmt.Errorf("min value %d is larger than max value %d", report.MinValue, report.MaxValue)
		}
		rangeSize = uint64(sumSize)
		// The count range [0, MaxCount] holds MaxCount+1 values, so it is the larger of the two
		// ranges whenever the sum range size is at most MaxCount (including when they are equal).
		if rangeSize <= report.MaxCount {
			rangeSize = report.MaxCount + 1
		}
	case config.ReportDefinition_UNIQUE_DEVICE_COUNTS,
		config.ReportDefinition_UNIQUE_DEVICE_HISTOGRAMS,
		config.ReportDefinition_HOURLY_VALUE_HISTOGRAMS:
		rangeSize = 1
	default:
		return rangeSize, fmt.Errorf("unsupported ReportType: %v", report.ReportType)
	}
	return rangeSize, nil
}

// getSparsityForReport computes the sparsity for a given |metric| and |report|: the max number of
// elements in the index vector representation of a contribution for |report|. It is the product of
// the max number of event vectors per contribution and the max number of buckets per event vector.
func getSparsityForReport(metric *config.MetricDefinition, report *config.ReportDefinition) (sparsity uint64, err error) {
	vectors, err := getNumEventVectorsPerContribution(metric, report)
	if err != nil {
		return 0, err
	}

	buckets, err := getNumBucketsPerEventVector(metric, report)
	if err != nil {
		return 0, err
	}

	sparsity = vectors * buckets
	return sparsity, nil
}

// getNumEventVectorsPerContribution gives the max number of event vectors which may be included
// in a contribution for |report|.
//
// UniqueDeviceCounts reports depend on the local aggregation procedure: SELECT_FIRST and
// SELECT_MOST_COMMON contribute a single event vector, while AT_LEAST_ONCE may contribute up to
// the event code buffer max of |metric|. All other supported report types use the event code
// buffer max. An error is returned for unsupported report types or aggregation procedures.
func getNumEventVectorsPerContribution(metric *config.MetricDefinition, report *config.ReportDefinition) (numEventVectors uint64, err error) {
	switch report.ReportType {
	case config.ReportDefinition_UNIQUE_DEVICE_COUNTS:
		switch report.LocalAggregationProcedure {
		case config.ReportDefinition_SELECT_FIRST,
			config.ReportDefinition_SELECT_MOST_COMMON:
			return 1, nil
		case config.ReportDefinition_AT_LEAST_ONCE:
			return getEventCodeBufferMax(metric), nil
		}
		return 0, fmt.Errorf("unexpected LocalAggregationProcedure: %v", report.LocalAggregationProcedure)

	case config.ReportDefinition_FLEETWIDE_OCCURRENCE_COUNTS,
		config.ReportDefinition_UNIQUE_DEVICE_HISTOGRAMS,
		config.ReportDefinition_HOURLY_VALUE_HISTOGRAMS,
		config.ReportDefinition_FLEETWIDE_HISTOGRAMS,
		config.ReportDefinition_FLEETWIDE_MEANS,
		config.ReportDefinition_UNIQUE_DEVICE_NUMERIC_STATS,
		config.ReportDefinition_HOURLY_VALUE_NUMERIC_STATS,
		config.ReportDefinition_STRING_COUNTS:
		return getEventCodeBufferMax(metric), nil
	}

	return 0, fmt.Errorf("unsupported ReportType: %v", report.ReportType)
}

// getEventCodeBufferMax gives the maximum number of event vectors for which a device stores local
// aggregates. This is the event_code_buffer_max value specified in the MetricDefinition when that
// field is non-zero; otherwise (the field is unset) it is the total number of valid event vectors.
func getEventCodeBufferMax(metric *config.MetricDefinition) (bufferMax uint64) {
	bufferMax = metric.EventCodeBufferMax
	if bufferMax == 0 {
		bufferMax = numEventVectors(metric)
	}
	return bufferMax
}

// numEventVectors computes the total number of valid event vectors for a MetricDefinition: the
// product of the number of valid event codes over all of its metric dimensions. A metric without
// dimensions has exactly one (empty) event vector.
func numEventVectors(metric *config.MetricDefinition) (numEventVectors uint64) {
	product := uint64(1)
	for _, dimension := range metric.MetricDimensions {
		product *= uint64(numEventCodes(dimension))
	}
	return product
}

// numEventCodes is a helper giving the number of valid event codes for a MetricDimension.
// When MaxEventCode is non-zero the result is MaxEventCode + 1; otherwise it is the number of
// entries in EventCodes.
func numEventCodes(dim *config.MetricDefinition_MetricDimension) (numEventCodes uint32) {
	numEventCodes = uint32(len(dim.EventCodes))
	if dim.MaxEventCode != 0 {
		numEventCodes = dim.MaxEventCode + 1
	}
	return numEventCodes
}

// getNumBucketsPerEventVector returns the max number of buckets which may be populated in a
// contribution for |report| for each event vector defined in |metric|.
// Returns 1 for reports for which a contribution consists of a single integer per event vector.
//
// An error is returned for STRING metrics (not supported yet), for unsupported metric types, for
// INTEGER metrics paired with a report type that is not handled here, and when the relevant
// histogram bucket configuration cannot be interpreted.
func getNumBucketsPerEventVector(metric *config.MetricDefinition, report *config.ReportDefinition) (numBuckets uint64, err error) {
	switch metric.MetricType {
	case config.MetricDefinition_OCCURRENCE:
		return 1, nil
	case config.MetricDefinition_INTEGER:
		switch report.ReportType {
		case config.ReportDefinition_FLEETWIDE_HISTOGRAMS:
			return getNumHistogramBuckets(report.IntBuckets)
		// An Observation for FLEETWIDE_MEANS is equivalent to a histogram with 2 buckets per event code:
		// one representing the sum, the other representing the count.
		case config.ReportDefinition_FLEETWIDE_MEANS:
			return 2, nil
		case
			config.ReportDefinition_UNIQUE_DEVICE_HISTOGRAMS,
			config.ReportDefinition_HOURLY_VALUE_HISTOGRAMS,
			config.ReportDefinition_UNIQUE_DEVICE_NUMERIC_STATS,
			config.ReportDefinition_HOURLY_VALUE_NUMERIC_STATS:
			return 1, nil
		default:
			// Fail loudly rather than silently reporting 0 buckets (and hence a sparsity of 0).
			return 0, fmt.Errorf("unsupported ReportType %v for metric type %v", report.ReportType, metric.MetricType)
		}
	case config.MetricDefinition_INTEGER_HISTOGRAM:
		return getNumHistogramBuckets(metric.IntBuckets)
	case config.MetricDefinition_STRING:
		return 0, fmt.Errorf("STRING metrics are not supported yet")
	default:
		return 0, fmt.Errorf("unsupported metric type %v", metric.MetricType)
	}
}

// getNumHistogramBuckets returns the number of histogram buckets in an IntegerBuckets.
//
// Both exponential and linear bucket configurations are supported. An error is returned if
// |buckets| is nil, if its bucket type is not set, or if the bucket type is not recognized.
func getNumHistogramBuckets(buckets *config.IntegerBuckets) (numBuckets uint64, err error) {
	// Guard against an unset IntegerBuckets message: reading its Buckets field would panic.
	if buckets == nil {
		return 0, fmt.Errorf("IntegerBuckets not set")
	}

	switch buckets.Buckets.(type) {
	case *config.IntegerBuckets_Exponential:
		return uint64(buckets.GetExponential().GetNumBuckets()), nil
	case *config.IntegerBuckets_Linear:
		return uint64(buckets.GetLinear().GetNumBuckets()), nil
	case nil:
		return 0, fmt.Errorf("IntegerBuckets type not set")
	default:
		return 0, fmt.Errorf("unexpected IntegerBuckets type")
	}
}

// getBestMappedKey selects the paramsMapKey{e, p, s} with the following properties:
// - |e| is the greatest mapped epsilon value which is *less than or equal to* |epsilon|
// - |p| is the greatest mapped population value which is *less than or equal to* |population|
// - |s| is the least mapped sparsity value which is *greater than or equal to* |sparsity|
// or returns an error if this is not possible (e.g. |epsilon| is smaller than all mapped epsilon values).
// On error the returned key is the zero key.
//
// This ensures that the returned parameters will provide |epsilon| (shuffled) differential privacy, or better.
func getBestMappedKey(epsilon float64, population uint64, sparsity uint64, mapped *paramsKeyLists) (key paramsMapKey, err error) {
	var best paramsMapKey

	if best.epsilon, err = mapped.getBestMappedEpsilon(epsilon); err != nil {
		return paramsMapKey{}, err
	}
	if best.population, err = mapped.getBestMappedPopulation(population); err != nil {
		return paramsMapKey{}, err
	}
	if best.sparsity, err = mapped.getBestMappedSparsity(sparsity); err != nil {
		return paramsMapKey{}, err
	}

	return best, nil
}

// getBestMappedEpsilon picks the greatest epsilon value in |lists| which is less than or equal to
// |epsilon|. A non-nil error is returned if |lists| has no epsilon values or if |epsilon| is
// smaller than the least of them. The epsilon list must be sorted in increasing order.
func (lists *paramsKeyLists) getBestMappedEpsilon(epsilon float64) (mappedEpsilon float64, err error) {
	candidates := lists.epsilons
	if len(candidates) == 0 {
		return 0, fmt.Errorf("list of mapped epsilon values is empty")
	}

	// Index of the first mapped value strictly greater than |epsilon|; the answer sits just before it.
	firstGreater := sort.Search(len(candidates), func(j int) bool { return candidates[j] > epsilon })
	if firstGreater == 0 {
		return 0, fmt.Errorf("input epsilon %v is outside the valid range", epsilon)
	}
	return candidates[firstGreater-1], nil
}

// getBestMappedPopulation picks the greatest population value in |lists| which is less than or
// equal to |population|. A non-nil error is returned if |lists| has no population values or if
// |population| is smaller than the least of them. The population list must be sorted in
// increasing order.
func (lists *paramsKeyLists) getBestMappedPopulation(population uint64) (mappedPopulation uint64, err error) {
	candidates := lists.populations
	if len(candidates) == 0 {
		return 0, fmt.Errorf("list of mapped population values is empty")
	}

	// Index of the first mapped value strictly greater than |population|; the answer sits just before it.
	firstGreater := sort.Search(len(candidates), func(j int) bool { return candidates[j] > population })
	if firstGreater == 0 {
		return 0, fmt.Errorf("input population %v is outside the valid range", population)
	}
	return candidates[firstGreater-1], nil
}

// getBestMappedSparsity picks the least sparsity value in |lists| which is greater than or equal
// to |sparsity|. A non-nil error is returned if |lists| has no sparsity values or if |sparsity| is
// larger than the greatest of them. The sparsity list must be sorted in increasing order.
func (lists *paramsKeyLists) getBestMappedSparsity(sparsity uint64) (mappedSparsity uint64, err error) {
	candidates := lists.sparsities
	if len(candidates) == 0 {
		return 0, fmt.Errorf("list of mapped sparsity values is empty")
	}

	// Index of the first mapped value which is at least |sparsity|, or len(candidates) if none is.
	firstAtLeast := sort.Search(len(candidates), func(j int) bool { return candidates[j] >= sparsity })
	if firstAtLeast == len(candidates) {
		return 0, fmt.Errorf("input sparsity %v is outside the valid range", sparsity)
	}
	return candidates[firstAtLeast], nil
}
