blob: f95f6ae98d079309c127a496af51c119d5c2fb8e [file] [log] [blame]
//
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file contains logic for handling public partitions.
package pbeam
import (
"bytes"
"encoding/base64"
"fmt"
"reflect"
"github.com/google/differential-privacy/privacy-on-beam/v3/internal/kv"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
)
func init() {
	// Combiner that builds the pMap of public partitions.
	register.Combiner2[pMap, beam.W](&partitionMapFn{})

	// Structural DoFns used when public partitions are given as a PCollection.
	register.DoFn2x3[beam.U, kv.Pair, beam.W, kv.Pair, error](&encodeIDVFn{})
	register.DoFn2x3[beam.W, kv.Pair, beam.U, kv.Pair, error](&decodeIDVFn{})

	// DoFns used when public partitions are given in memory.
	register.DoFn3x1[beam.U, beam.W, func(beam.U, beam.W), error](&prunePartitionsInMemoryVFn{})
	register.Emitter2[beam.U, beam.W]()
	register.DoFn3x0[beam.U, kv.Pair, func(beam.U, kv.Pair)](&prunePartitionsInMemoryKVFn{})
	register.Emitter2[beam.U, kv.Pair]()

	// Functions that pair each public partition with a placeholder value.
	register.Function1x2[beam.W, beam.W, int64](addZeroValuesToPublicPartitionsInt64)
	register.Function1x2[beam.W, beam.W, float64](addZeroValuesToPublicPartitionsFloat64)
	register.Function1x2[beam.W, beam.W, []float64](addEmptySliceToPublicPartitionsFloat64)

	// prunePartitionsKV reads the public partitions as a side input through a
	// func(*pMap) bool iterator; register that iterator shape alongside the function
	// so it does not need the reflection-based fallback.
	register.Function4x1[beam.U, kv.Pair, func(*pMap) bool, func(beam.U, kv.Pair), error](prunePartitionsKV)
	register.Iter1[pMap]()

	// Post-CoGroupByKey merge functions and the iterators/emitters they use.
	register.Function4x0[beam.V, func(*int64) bool, func(*beam.U) bool, func(beam.U, beam.V)](mergePublicValues)
	register.Iter1[int64]()
	register.Iter1[beam.U]()
	register.Emitter2[beam.U, beam.V]()
	register.Function4x0[beam.W, func(*beam.V) bool, func(*beam.V) bool, func(beam.W, beam.V)](mergeResultWithEmptyPublicPartitionsFn)
	register.Iter1[beam.V]()
	register.Emitter2[beam.W, beam.V]()
}
// newAddZeroValuesToPublicPartitionsFn returns a function that, when applied with
// beam.ParDo, turns a PCollection<V> into a PCollection<V,0>, where the zero is of the
// numeric kind given by vKind.
//
// Only reflect.Int64 and reflect.Float64 are supported; any other kind yields a nil
// function and an error.
func newAddZeroValuesToPublicPartitionsFn(vKind reflect.Kind) (any, error) {
	if vKind == reflect.Int64 {
		return addZeroValuesToPublicPartitionsInt64, nil
	}
	if vKind == reflect.Float64 {
		return addZeroValuesToPublicPartitionsFloat64, nil
	}
	return nil, fmt.Errorf("vKind(%v) should be int64 or float64", vKind)
}
// addZeroValuesToPublicPartitionsInt64 pairs a partition key with an int64 zero.
func addZeroValuesToPublicPartitionsInt64(partition beam.W) (k beam.W, v int64) {
	var zero int64
	return partition, zero
}

// addZeroValuesToPublicPartitionsFloat64 pairs a partition key with a float64 zero.
func addZeroValuesToPublicPartitionsFloat64(partition beam.W) (k beam.W, v float64) {
	var zero float64
	return partition, zero
}

// addEmptySliceToPublicPartitionsFloat64 pairs a partition key with a non-nil, empty
// []float64.
func addEmptySliceToPublicPartitionsFloat64(partition beam.W) (k beam.W, v []float64) {
	empty := make([]float64, 0)
	return partition, empty
}
// pMap holds a set of partition keys for quick lookup as a map from string to bool.
// Key is the base64 string representation of an encoded partition key.
// Value is set to true if the partition key exists. A key that is absent reads as
// false, so a plain index expression works as a membership test (this also holds
// for a nil pMap).
type pMap map[string]bool
// dropNonPublicPartitions returns the PCollection with the non-public partitions dropped if public partitions are
// specified. Returns the input PCollection otherwise.
//
// publicPartitions may be nil, a beam.PCollection of partition keys, or a slice/array of
// partition keys of type partitionType (held in memory). When pcol.codec is non-nil the
// data is treated as <PrivacyKey, kv.Pair{PartitionKey, Value}>; otherwise as
// <PrivacyKey, PartitionKey>.
//
// An error is returned if an in-memory partition key cannot be encoded; the error wraps
// the encoder's error.
func dropNonPublicPartitions(s beam.Scope, pcol PrivatePCollection, publicPartitions any, partitionType reflect.Type) (beam.PCollection, error) {
	// Obtain type information from the underlying PCollection<K,V>.
	idT, _ := beam.ValidateKVType(pcol.col)
	// If PublicPartitions is not specified, return the input collection.
	if publicPartitions == nil {
		return pcol.col, nil
	}
	// Drop non-public partitions, if public partitions are specified as a PCollection.
	if publicPartitionsCol, ok := publicPartitions.(beam.PCollection); ok {
		// Data is <PrivacyKey, PartitionKey, Value>
		if pcol.codec != nil {
			return dropNonPublicPartitionsKVFn(s, publicPartitionsCol, pcol, idT), nil
		}
		// Data is <PrivacyKey, PartitionKey>
		return dropNonPublicPartitionsVFn(s, publicPartitionsCol, pcol), nil
	}
	// Drop non-public partitions, public partitions are specified as slice/array (i.e., in-memory).
	// Convert PublicPartitions to map for quick lookup.
	partitions := reflect.ValueOf(publicPartitions)
	numPartitions := partitions.Len()
	partitionEnc := beam.NewElementEncoder(partitionType)
	partitionMap := make(pMap, numPartitions)
	// The buffer is reused across iterations: EncodeToString copies its bytes, so
	// resetting it afterwards does not affect keys already stored in the map.
	var partitionBuf bytes.Buffer
	for i := 0; i < numPartitions; i++ {
		partitionKey := partitions.Index(i).Interface()
		partitionBuf.Reset()
		if err := partitionEnc.Encode(partitionKey, &partitionBuf); err != nil {
			return pcol.col, fmt.Errorf("couldn't encode partition %v: %w", partitionKey, err)
		}
		partitionMap[base64.StdEncoding.EncodeToString(partitionBuf.Bytes())] = true
	}
	// Data is <PrivacyKey, PartitionKey, Value>
	if pcol.codec != nil {
		return beam.ParDo(s, newPrunePartitionsInMemoryKVFn(partitionMap), pcol.col), nil
	}
	// Data is <PrivacyKey, PartitionKey>
	partitionEncodedType := beam.EncodedType{T: partitionType}
	return beam.ParDo(s, newPrunePartitionsInMemoryVFn(partitionEncodedType, partitionMap), pcol.col), nil
}
// mergePublicValues is applied to the output of a CoGroupByKey whose first input is the
// public partitions (each paired with an int64 placeholder) and whose second input holds
// the privacy keys for the grouping key value.
//
// If at least one placeholder is present (i.e. value is a public partition), every
// privacy key is emitted as a <privacyKey, value> pair. Otherwise nothing is emitted.
// value is the grouped value: the partition key for Count & DistinctPrivacyID, and a
// kv.Pair for other aggregations.
func mergePublicValues(value beam.V, isKnown func(*int64) bool, privacyKeys func(*beam.U) bool, emit func(beam.U, beam.V)) {
	// Only the presence of a placeholder matters; its value is discarded.
	var placeholder int64
	if !isKnown(&placeholder) {
		return
	}
	var key beam.U
	for privacyKeys(&key) {
		emit(key, value)
	}
}
// dropNonPublicPartitionsVFn removes from pcol every element whose value is not one of
// the public partitions. It is meant for aggregations over V values, such as Count and
// DistinctPrivacyID.
//
// The filtering is a join:
//  1. Key each public partition with an int64(0) placeholder: <V> -> <V, 0>.
//  2. Swap pcol.col: <PrivacyKey, V> -> <V, PrivacyKey>.
//  3. CoGroupByKey the results of 1 and 2.
//  4. Using mergePublicValues, emit <PrivacyKey, V> only for groups that received
//     a placeholder from step 1.
//
// The result is a PCollection<PrivacyKey, Value> restricted to public partitions.
func dropNonPublicPartitionsVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection) beam.PCollection {
	placeholders := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
	keyedByValue := beam.SwapKV(s, pcol.col)
	joined := beam.CoGroupByKey(s, placeholders, keyedByValue)
	return beam.ParDo(s, mergePublicValues, joined)
}
// dropNonPublicPartitionsKVFn removes from pcol every element whose partition key is not
// one of the public partitions. It is meant for aggregations over <K,V> pairs, such as
// SumPerKey and MeanPerKey.
//
// The filtering is a join:
//  1. Key each public partition with an int64(0) placeholder:
//     <PartitionKey> -> <PartitionKey, 0>.
//  2. Re-key pcol.col with encodeIDVFn:
//     <PrivacyKey, <PartitionKey, Value>> -> <PartitionKey, <PrivacyKey, Value>>.
//  3. CoGroupByKey the results of 1 and 2.
//  4. Using mergePublicValues, keep <PartitionKey, <PrivacyKey, Value>> only for groups
//     that received a placeholder from step 1.
//  5. Swap and decode with decodeIDVFn back to <PrivacyKey, <PartitionKey, Value>>.
//
// This is fine for small partitions. Steps 3 and 4 become a bottleneck when some
// partitions receive a very large number of user contributions.
//
// The result is a PCollection<PrivacyKey, <PartitionKey, Value>> restricted to public
// partitions.
func dropNonPublicPartitionsKVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection, idType typex.FullType) beam.PCollection {
	placeholders := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
	partitionKeyType := beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T}
	keyedByPartition := beam.ParDo(s, newEncodeIDVFn(idType, pcol.codec), pcol.col, partitionKeyType)
	joined := beam.CoGroupByKey(s, placeholders, keyedByPartition)
	kept := beam.ParDo(s, mergePublicValues, joined)
	keyedByID := beam.SwapKV(s, kept)
	idvCodec := kv.NewCodec(idType.Type(), pcol.codec.VType.T)
	idTypeDef := beam.TypeDefinition{Var: beam.UType, T: idType.Type()}
	return beam.ParDo(s, newDecodeIDVFn(pcol.codec.KType, idvCodec), keyedByID, idTypeDef)
}
// encodeIDVFn takes a PCollection<ID,kv.Pair{K,V}> as input, and returns a
// PCollection<K, kv.Pair{ID,V}>; where ID and V have been coded, and K has been
// decoded.
type encodeIDVFn struct {
	IDType         beam.EncodedType    // Type information of the privacy ID
	idEnc          beam.ElementEncoder // Encoder for privacy ID, set during Setup() according to IDType
	InputPairCodec *kv.Codec           // Codec for the input kv.Pair{K,V}
}

func newEncodeIDVFn(idType typex.FullType, kvCodec *kv.Codec) *encodeIDVFn {
	return &encodeIDVFn{
		IDType:         beam.EncodedType{T: idType.Type()},
		InputPairCodec: kvCodec,
	}
}

// Setup builds the privacy ID encoder from IDType and sets up InputPairCodec.
func (fn *encodeIDVFn) Setup() error {
	fn.idEnc = beam.NewElementEncoder(fn.IDType.T)
	return fn.InputPairCodec.Setup()
}

// ProcessElement encodes id and decodes the partition key of pair. It returns the
// decoded key together with kv.Pair{encoded id, pair.V}; pair.V is passed through
// unchanged.
//
// On failure it returns zero values and an error that wraps the underlying
// encode/decode error.
func (fn *encodeIDVFn) ProcessElement(id beam.U, pair kv.Pair) (beam.W, kv.Pair, error) {
	var idBuf bytes.Buffer
	if err := fn.idEnc.Encode(id, &idBuf); err != nil {
		return nil, kv.Pair{}, fmt.Errorf("pbeam.encodeIDVFn.ProcessElement: couldn't encode ID %v: %w", id, err)
	}
	k, _, err := fn.InputPairCodec.Decode(pair)
	if err != nil {
		return nil, kv.Pair{}, fmt.Errorf("pbeam.encodeIDVFn.ProcessElement: couldn't decode partition key: %w", err)
	}
	return k, kv.Pair{K: idBuf.Bytes(), V: pair.V}, nil
}
// decodeIDVFn is the reverse operation of encodeIDVFn. It takes a PCollection<K, kv.Pair{ID,V}>
// as input, and returns a PCollection<ID, kv.Pair{K,V}>; where K and V have been coded, and ID
// has been decoded.
type decodeIDVFn struct {
	KType          beam.EncodedType    // Type information of the partition key K
	kEnc           beam.ElementEncoder // Encoder for partition key, set during Setup() according to KType
	InputPairCodec *kv.Codec           // Codec for the input kv.Pair{ID,V}
}

func newDecodeIDVFn(kType beam.EncodedType, idvCodec *kv.Codec) *decodeIDVFn {
	return &decodeIDVFn{
		KType:          kType,
		InputPairCodec: idvCodec,
	}
}

// Setup builds the partition key encoder from KType and sets up InputPairCodec.
func (fn *decodeIDVFn) Setup() error {
	fn.kEnc = beam.NewElementEncoder(fn.KType.T)
	return fn.InputPairCodec.Setup()
}

// ProcessElement encodes k and decodes the privacy ID of pair. It returns the decoded
// ID together with kv.Pair{encoded k, pair.V}; pair.V is passed through unchanged.
//
// On failure it returns zero values and an error that wraps the underlying
// encode/decode error.
func (fn *decodeIDVFn) ProcessElement(k beam.W, pair kv.Pair) (beam.U, kv.Pair, error) {
	var kBuf bytes.Buffer
	if err := fn.kEnc.Encode(k, &kBuf); err != nil {
		return nil, kv.Pair{}, fmt.Errorf("pbeam.decodeIDVFn.ProcessElement: couldn't encode K %v: %w", k, err)
	}
	id, _, err := fn.InputPairCodec.Decode(pair)
	if err != nil {
		return nil, kv.Pair{}, fmt.Errorf("pbeam.decodeIDVFn.ProcessElement: couldn't decode ID: %w", err)
	}
	// pair.V is the V in PCollection<K, kv.Pair{ID,V}>.
	return id, kv.Pair{K: kBuf.Bytes(), V: pair.V}, nil
}
// partitionMapFn is a combiner that collects public partition keys into a single pMap.
// Each key is element-encoded according to PartitionType and stored base64-encoded.
type partitionMapFn struct {
	PartitionType beam.EncodedType    // Type of the partition key; used to build partitionEnc.
	partitionEnc  beam.ElementEncoder // Set in Setup() from PartitionType.
}

func newPartitionMapFn(partitionType beam.EncodedType) *partitionMapFn {
	return &partitionMapFn{PartitionType: partitionType}
}

// Setup is our "constructor": it builds the partition key encoder.
func (fn *partitionMapFn) Setup() {
	fn.partitionEnc = beam.NewElementEncoder(fn.PartitionType.T)
}

// CreateAccumulator returns a new, empty pMap.
func (fn *partitionMapFn) CreateAccumulator() pMap {
	return make(pMap)
}

// AddInput encodes partitionKey and records its base64 form in p. A nil p is replaced
// by a fresh map, so writing to it cannot panic. On an encoding error, p is returned
// unchanged together with a wrapped error.
func (fn *partitionMapFn) AddInput(p pMap, partitionKey beam.W) (pMap, error) {
	var partitionBuf bytes.Buffer
	if err := fn.partitionEnc.Encode(partitionKey, &partitionBuf); err != nil {
		return p, fmt.Errorf("pbeam.partitionMapFn.AddInput: couldn't encode partition key %v: %w", partitionKey, err)
	}
	if p == nil {
		p = make(pMap)
	}
	p[base64.StdEncoding.EncodeToString(partitionBuf.Bytes())] = true
	return p, nil
}

// MergeAccumulators returns the union of a and b. The keys of the smaller map are
// copied into the larger one, which is modified in place and returned. This keeps the
// work proportional to the smaller map and never writes into a nil map: a nil input
// has length 0, so it is only ever the map being read from.
func (fn *partitionMapFn) MergeAccumulators(a, b pMap) pMap {
	if len(a) > len(b) {
		a, b = b, a
	}
	for k := range a {
		b[k] = true
	}
	return b
}
// prunePartitionsInMemoryVFn filters a PCollection<ID, PartitionKey>. It keeps only
// elements whose encoded (then base64-encoded) partition key is present in
// PartitionMap.
type prunePartitionsInMemoryVFn struct {
	PartitionType beam.EncodedType    // Type of the partition key; used to build partitionEnc.
	partitionEnc  beam.ElementEncoder // Set in Setup() from PartitionType.
	PartitionMap  pMap                // Base64-encoded public partition keys.
}

func newPrunePartitionsInMemoryVFn(partitionType beam.EncodedType, partitionMap pMap) *prunePartitionsInMemoryVFn {
	fn := &prunePartitionsInMemoryVFn{
		PartitionType: partitionType,
		PartitionMap:  partitionMap,
	}
	return fn
}

// Setup builds the partition key encoder from PartitionType.
func (fn *prunePartitionsInMemoryVFn) Setup() {
	fn.partitionEnc = beam.NewElementEncoder(fn.PartitionType.T)
}

// ProcessElement emits (id, partitionKey) only when partitionKey is a public
// partition. It returns an error if partitionKey cannot be encoded.
func (fn *prunePartitionsInMemoryVFn) ProcessElement(id beam.U, partitionKey beam.W, emit func(beam.U, beam.W)) error {
	var encoded bytes.Buffer
	err := fn.partitionEnc.Encode(partitionKey, &encoded)
	if err != nil {
		return fmt.Errorf("pbeam.prunePartitionsInMemoryVFn.ProcessElement: couldn't encode partition %v: %w", partitionKey, err)
	}
	lookupKey := base64.StdEncoding.EncodeToString(encoded.Bytes())
	if !fn.PartitionMap[lookupKey] {
		return nil
	}
	emit(id, partitionKey)
	return nil
}
// prunePartitionsInMemoryKVFn filters a PCollection<ID, kv.Pair{K,V}>. It keeps only
// elements whose partition key K (base64-encoded) is present in PartitionMap.
type prunePartitionsInMemoryKVFn struct {
	PartitionMap pMap // Base64-encoded public partition keys.
}

func newPrunePartitionsInMemoryKVFn(partitionMap pMap) *prunePartitionsInMemoryKVFn {
	fn := &prunePartitionsInMemoryKVFn{}
	fn.PartitionMap = partitionMap
	return fn
}

// ProcessElement emits (id, pair) only when pair.K is a public partition.
func (fn *prunePartitionsInMemoryKVFn) ProcessElement(id beam.U, pair kv.Pair, emit func(beam.U, kv.Pair)) {
	// pair.K already holds the encoded partition key, so only the base64 step is
	// needed to build the lookup key.
	lookupKey := base64.StdEncoding.EncodeToString(pair.K)
	if isPublic := fn.PartitionMap[lookupKey]; isPublic {
		emit(id, pair)
	}
}
// prunePartitionsKV takes a PCollection<ID, kv.Pair{K,V}> as input, and returns a
// PCollection<ID, kv.Pair{K,V}>, where non-public partitions have been dropped.
// Used for sum and mean.
//
// partitionsIter yields the pMap of public partitions (base64-encoded keys). Only the
// first value is read. If the iterator yields nothing, or yields a nil map, every
// element is dropped.
//
// The returned error is always nil. The error result is kept for compatibility with
// the registered function signature.
func prunePartitionsKV(id beam.U, pair kv.Pair, partitionsIter func(*pMap) bool, emit func(beam.U, kv.Pair)) error {
	var partitionMap pMap
	if !partitionsIter(&partitionMap) || partitionMap == nil {
		return nil
	}
	// Partition Key in a kv.Pair is already encoded, we just convert it to base64 encoding.
	if partitionMap[base64.StdEncoding.EncodeToString(pair.K)] {
		emit(id, pair)
	}
	return nil
}
// mergeResultWithEmptyPublicPartitionsFn handles the output of a CoGroupByKey for key
// k. If resultIter yields a value, that value is emitted. Otherwise, if
// publicPartitionsIter yields one, that value is emitted. If neither yields anything,
// nothing is emitted.
//
// publicPartitionsIter is presumably the placeholder for a public partition that
// received no data; confirm at the call site.
func mergeResultWithEmptyPublicPartitionsFn(k beam.W, resultIter, publicPartitionsIter func(*beam.V) bool, emit func(beam.W, beam.V)) {
	var out beam.V
	// Short-circuit: publicPartitionsIter is consulted only when resultIter is empty.
	if resultIter(&out) || publicPartitionsIter(&out) {
		emit(k, out)
	}
}
// checkPublicPartitions returns an error if publicPartitions parameter of an aggregation
// is not valid.
//
// A nil publicPartitions is valid. Otherwise it must be one of:
//   - a valid beam.PCollection whose element type is partitionType, or
//   - a slice or array whose element type is partitionType.
func checkPublicPartitions(publicPartitions any, partitionType reflect.Type) error {
	if publicPartitions == nil {
		return nil
	}
	if publicPartitionsCol, ok := publicPartitions.(beam.PCollection); ok {
		// IsValid is checked first: Type() must not be called on an invalid PCollection.
		if !publicPartitionsCol.IsValid() || partitionType != publicPartitionsCol.Type().Type() {
			return fmt.Errorf("PublicPartitions=%+v needs to be a valid beam.PCollection with the same type as the partition key (%+v)", publicPartitions, partitionType)
		}
		return nil
	}
	publicPartitionsType := reflect.TypeOf(publicPartitions)
	if kind := publicPartitionsType.Kind(); kind != reflect.Slice && kind != reflect.Array {
		return fmt.Errorf("PublicPartitions=%+v needs to be a beam.PCollection, slice or array", publicPartitionsType)
	}
	if publicPartitionsType.Elem() != partitionType {
		return fmt.Errorf("PublicPartitions=%+v needs to be a slice or an array whose elements are the same type as the partition key (%+v)", publicPartitions, partitionType)
	}
	return nil
}