blob: 923340a8560edbea01399775229e379d999d285c [file] [log] [blame]
//
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package main
import (
"fmt"
"io/ioutil"
"sort"
"strconv"
"strings"
"github.com/apache/beam/sdks/v2/go/pkg/beam/io/textio"
"github.com/google/differential-privacy/privacy-on-beam/v2/codelab"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)
// Names of the codelab examples. drawPlot switches on these values to choose
// the y-axis label and the title of the plots it draws.
const (
	count            = "count"             // visits per hour
	publicPartitions = "public_partitions" // also visits per hour
	mean             = "mean"              // mean time spent
	sum              = "sum"               // revenue per hour
)
func drawPlot(hourToValue, dpHourToValue map[int]float64, example, nonDPOutput, dpOutput string) error {
// Sort dp and non-dp points.
keys := make([]int, 0)
for k := range hourToValue {
keys = append(keys, k)
}
sort.Ints(keys)
points := make([]float64, 0)
for _, k := range keys {
points = append(points, hourToValue[k])
}
dpKeys := make([]int, 0)
for k := range dpHourToValue {
dpKeys = append(dpKeys, k)
}
sort.Ints(dpKeys)
dpPoints := make([]float64, 0)
for _, k := range dpKeys {
dpPoints = append(dpPoints, dpHourToValue[k])
}
p := plot.New()
p.X.Label.Text = "Hour"
switch example {
case count, publicPartitions: // count & publicPartitions both compute visits per hour.
p.Y.Label.Text = "Visits"
p.Title.Text = "Visits Per Hour"
case mean:
p.Y.Label.Text = "Time Spent"
p.Title.Text = "Mean Time Spent"
case sum:
p.Y.Label.Text = "Revenue"
p.Title.Text = "Revenue Per Hour"
default:
return fmt.Errorf("unknown example %q specified, please use one of 'count', 'sum', 'mean', 'public_partitions'", example)
}
w := vg.Points(20)
// Non-DP Plot
bars, err := plotter.NewBarChart(plotter.Values(points), w)
if err != nil {
return fmt.Errorf("could not create bars from points %v: %v", plotter.Values(points), err)
}
bars.LineStyle.Width = vg.Length(0)
bars.Color = plotutil.Color(2)
p.Add(bars)
p.NominalX("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23")
// Save non-dp plot.
if err := p.Save(10*vg.Inch, 5*vg.Inch, nonDPOutput); err != nil {
return fmt.Errorf("Could not save plot: %v", err)
}
// DP Plot
dpBars, err := plotter.NewBarChart(plotter.Values(dpPoints), w)
if err != nil {
return fmt.Errorf("could not create bars from points %v: %v", plotter.Values(dpPoints), err)
}
dpBars.LineStyle.Width = vg.Length(0)
dpBars.Color = plotutil.Color(3)
dpBars.Offset = w
p.Add(dpBars)
p.Legend.Add("Raw", bars)
p.Legend.Add("Private", dpBars)
p.Legend.Top = true
// Save dp plot.
if err := p.Save(15*vg.Inch, 5*vg.Inch, dpOutput); err != nil {
return fmt.Errorf("Could not save plot: %v", err)
}
return nil
}
// readInput builds the pipeline stage that reads the .csv file at input, whose
// lines describe restaurant visits as "visitor_id, visit time, minutes spent,
// money spent", and converts each line with codelab.CreateVisitsFn. The
// returned PCollection holds the resulting Visit values.
func readInput(s beam.Scope, input string) beam.PCollection {
	scope := s.Scope("readInput")
	return beam.ParDo(scope, codelab.CreateVisitsFn, textio.Read(scope, input))
}
// writeOutput builds the pipeline stage that maps every element of output
// through convertToPairFn, merges the whole collection into one value with
// normalizeOutputCombineFn, and writes that value to the text file
// outputTextName.
// NOTE(review): readOutput expects "hour value" lines; presumably that is the
// format normalizeOutputCombineFn produces — confirm against its definition.
func writeOutput(s beam.Scope, output beam.PCollection, outputTextName string) {
	scope := s.Scope("writeOutput")
	pairs := beam.ParDo(scope, convertToPairFn, output)
	textio.Write(scope, outputTextName, beam.Combine(scope, &normalizeOutputCombineFn{}, pairs))
}
// readOutput reads a text file in which every non-blank line holds an hour
// (int) and a value (float64) separated by whitespace, e.g. "13 42.500000",
// and returns a map from each hour to its value. If an hour occurs on several
// lines, the last occurrence wins.
//
// It returns an error if the file cannot be read (the underlying error is
// wrapped, so errors.Is works on it), or if a non-blank line does not consist
// of exactly two fields that parse as an int and a float64.
func readOutput(output string) (map[int]float64, error) {
	contents, err := ioutil.ReadFile(output)
	if err != nil {
		return nil, fmt.Errorf("could not read output file %s: %w", output, err)
	}
	hourToValue := make(map[int]float64)
	for _, line := range strings.Split(string(contents), "\n") {
		// strings.Fields tolerates repeated spaces, tabs and the trailing "\r"
		// left by CRLF line endings, all of which break Split(line, " ").
		elements := strings.Fields(line)
		if len(elements) == 0 {
			continue // blank or whitespace-only line
		}
		if len(elements) != 2 {
			return nil, fmt.Errorf("got %d number of elements in line %q, expected 2", len(elements), line)
		}
		hour, err := strconv.Atoi(elements[0])
		if err != nil {
			return nil, fmt.Errorf("could not convert hour %s to int: %w", elements[0], err)
		}
		value, err := strconv.ParseFloat(elements[1], 64)
		if err != nil {
			return nil, fmt.Errorf("could not convert value %s to float64: %w", elements[1], err)
		}
		hourToValue[hour] = value
	}
	return hourToValue, nil
}