| // |
| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| |
| // This file contains logic for handling public partitions. |
| |
| package pbeam |
| |
| import ( |
| "bytes" |
| "encoding/base64" |
| "fmt" |
| "math/rand" |
| "reflect" |
| |
| "flag" |
| "github.com/google/differential-privacy/privacy-on-beam/v3/internal/kv" |
| "github.com/apache/beam/sdks/v2/go/pkg/beam" |
| "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" |
| "github.com/apache/beam/sdks/v2/go/pkg/beam/register" |
| "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter" |
| ) |
| |
// enableShardedPublicPartitions selects the sharded implementation
// (filterKeysImbalanced) when public partitions are supplied as a
// PCollection. It is read by dropNonPublicPartitionsVFn and
// dropNonPublicPartitionsKVFn and defaults to false.
var enableShardedPublicPartitions = flag.Bool(
	"enable_sharded_public_partitions",
	false,
	"Enable sharded public partitions. This is a temporary flag to allow us to test the new sharded implementation of public partition filtering.",
)
| |
| func init() { |
| register.Combiner2[pMap, beam.W](&partitionMapFn{}) |
| |
| register.DoFn2x3[beam.U, kv.Pair, beam.W, kv.Pair, error](&encodeIDVFn{}) |
| register.DoFn2x3[beam.W, kv.Pair, beam.U, kv.Pair, error](&decodeIDVFn{}) |
| register.DoFn3x1[beam.U, beam.W, func(beam.U, beam.W), error](&prunePartitionsInMemoryVFn{}) |
| register.Emitter2[beam.U, beam.W]() |
| register.DoFn3x0[beam.U, kv.Pair, func(beam.U, kv.Pair)](&prunePartitionsInMemoryKVFn{}) |
| register.Emitter2[beam.U, kv.Pair]() |
| |
| register.Function1x2[beam.W, beam.W, int64](addZeroValuesToPublicPartitionsInt64) |
| register.Function1x2[beam.W, beam.W, float64](addZeroValuesToPublicPartitionsFloat64) |
| register.Function1x2[beam.W, beam.W, []float64](addEmptySliceToPublicPartitionsFloat64) |
| register.Function4x1[beam.U, kv.Pair, func(*pMap) bool, func(beam.U, kv.Pair), error](prunePartitionsKV) |
| register.Function4x0[beam.V, func(*int64) bool, func(*beam.U) bool, func(beam.U, beam.V)](mergePublicValues) |
| register.Iter1[int64]() |
| register.Iter1[beam.U]() |
| register.Emitter2[beam.U, beam.V]() |
| register.Function4x0[beam.W, func(*beam.V) bool, func(*beam.V) bool, func(beam.W, beam.V)](mergeResultWithEmptyPublicPartitionsFn) |
| register.Iter1[beam.V]() |
| register.Emitter2[beam.W, beam.V]() |
| |
| register.Function1x2[kv.Pair, ShardedKey, []byte](addRandomShardIDFn) |
| register.Function4x0[[]byte, func(*[]byte) bool, func(*int) bool, func(ShardedKey, int)](extractEmittableKeysWithShardIDFn) |
| register.DoFn4x1[ShardedKey, func(*int) bool, func(*[]byte) bool, func(beam.T, beam.W), error](&filterKeysWithShardIDFn{}) |
| register.Function1x2[kv.Pair, []byte, []byte](unwrapPairFn) |
| } |
| |
| // newAddZeroValuesToPublicPartitionsFn turns a PCollection<V> into PCollection<V,0>. |
| func newAddZeroValuesToPublicPartitionsFn(vKind reflect.Kind) (any, error) { |
| switch vKind { |
| case reflect.Int64: |
| return addZeroValuesToPublicPartitionsInt64, nil |
| case reflect.Float64: |
| return addZeroValuesToPublicPartitionsFloat64, nil |
| default: |
| return nil, fmt.Errorf("vKind(%v) should be int64 or float64", vKind) |
| } |
| } |
| |
| func addZeroValuesToPublicPartitionsInt64(partition beam.W) (k beam.W, v int64) { |
| return partition, 0 |
| } |
| |
| func addZeroValuesToPublicPartitionsFloat64(partition beam.W) (k beam.W, v float64) { |
| return partition, 0 |
| } |
| |
| func addEmptySliceToPublicPartitionsFloat64(partition beam.W) (k beam.W, v []float64) { |
| return partition, []float64{} |
| } |
| |
// pMap is a set of partition keys, stored as a map from string to bool for
// quick membership lookup.
//
// Each key is the base64 (standard encoding) string of an encoded partition
// key. The value is true for every partition key that is present; looking up
// an absent key yields false.
type pMap map[string]bool
| |
| // dropNonPublicPartitions returns the PCollection with the non-public partitions dropped if public partitions are |
| // specified. Returns the input PCollection otherwise. |
| func dropNonPublicPartitions(s beam.Scope, pcol PrivatePCollection, publicPartitions any, partitionType reflect.Type) (beam.PCollection, error) { |
| // Obtain type information from the underlying PCollection<K,V>. |
| idT, _ := beam.ValidateKVType(pcol.col) |
| |
| // If PublicPartitions is not specified, return the input collection. |
| if publicPartitions == nil { |
| return pcol.col, nil |
| } |
| |
| // Drop non-public partitions, if public partitions are specified as a PCollection. |
| if publicPartitionscCol, ok := publicPartitions.(beam.PCollection); ok { |
| // Data is <PrivacyKey, PartitionKey, Value> |
| if pcol.codec != nil { |
| return dropNonPublicPartitionsKVFn(s, publicPartitionscCol, pcol, idT), nil |
| } |
| // Data is <PrivacyKey, PartitionKey> |
| return dropNonPublicPartitionsVFn(s, publicPartitionscCol, pcol), nil |
| } |
| |
| // Drop non-public partitions, public partitions are specified as slice/array (i.e., in-memory). |
| // Convert PublicPartitions to map for quick lookup. |
| partitionEnc := beam.NewElementEncoder(partitionType) |
| partitionMap := pMap{} |
| for i := 0; i < reflect.ValueOf(publicPartitions).Len(); i++ { |
| partitionKey := reflect.ValueOf(publicPartitions).Index(i).Interface() |
| var partitionBuf bytes.Buffer |
| if err := partitionEnc.Encode(partitionKey, &partitionBuf); err != nil { |
| return pcol.col, fmt.Errorf("couldn't encode partition %v: %v", partitionKey, err) |
| } |
| partitionMap[base64.StdEncoding.EncodeToString(partitionBuf.Bytes())] = true |
| } |
| // Data is <PrivacyKey, PartitionKey, Value> |
| if pcol.codec != nil { |
| return beam.ParDo(s, newPrunePartitionsInMemoryKVFn(partitionMap), pcol.col), nil |
| } |
| // Data is <PrivacyKey, PartitionKey> |
| partitionEncodedType := beam.EncodedType{partitionType} |
| return beam.ParDo(s, newPrunePartitionsInMemoryVFn(partitionEncodedType, partitionMap), pcol.col), nil |
| } |
| |
| // mergePublicValues merges the public partitions with the values for a PrivatePCollection |
| // after a CoGroupByKey. Only outputs a <privacyKey, v> pair (where v is value in the case |
| // of Count & DistinctPrivacyID, and kv.Pair for other aggregations) if the value is in |
| // the public partitions, i.e., the PCollection that is passed to the CoGroupByKey first. |
| func mergePublicValues(value beam.V, isKnown func(*int64) bool, privacyKeys func(*beam.U) bool, emit func(beam.U, beam.V)) { |
| var ignoredZero int64 |
| if isKnown(&ignoredZero) { |
| var privacyKey beam.U |
| for privacyKeys(&privacyKey) { |
| emit(privacyKey, value) |
| } |
| } |
| } |
| |
// ShardedKey is a key encoded as bytes, combined with an int shard ID.
type ShardedKey struct {
	// K is the key in its encoded (byte) form.
	K []byte
	// ShardID identifies the shard this key instance is assigned to.
	ShardID int
}
| |
| func addRandomShardIDFn(encoded kv.Pair) (ShardedKey, []byte) { |
| return ShardedKey{K: encoded.K, ShardID: rand.Intn(2048)}, encoded.V |
| } |
| |
| func extractEmittableKeysWithShardIDFn(k []byte, isAllowedKeyIter func(*[]byte) bool, shardIDIter func(*int) bool, emit func(ShardedKey, int)) { |
| var isAllowedKey []byte |
| if !isAllowedKeyIter(&isAllowedKey) { |
| // k is not an allow listed key, filter it out. |
| return |
| } |
| |
| var subkey int |
| for shardIDIter(&subkey) { |
| emit(ShardedKey{K: k, ShardID: subkey}, 0) |
| } |
| } |
| |
| type filterKeysWithShardIDFn struct { |
| KType beam.EncodedType |
| VType beam.EncodedType |
| PairCodec *kv.Codec |
| } |
| |
| func newFilterKeysWithShardIDFn(kType, vType reflect.Type) *filterKeysWithShardIDFn { |
| return &filterKeysWithShardIDFn{ |
| KType: beam.EncodedType{T: kType}, |
| VType: beam.EncodedType{T: vType}, |
| } |
| } |
| |
| func (fn *filterKeysWithShardIDFn) Setup() error { |
| fn.PairCodec = kv.NewCodec(fn.KType.T, fn.VType.T) |
| return fn.PairCodec.Setup() |
| } |
| |
| func (fn *filterKeysWithShardIDFn) ProcessElement(k ShardedKey, isEmittableShardIDIter func(*int) bool, pcolValueIter func(*[]byte) bool, emit func(beam.T, beam.W)) error { |
| var isEmittableShardID int |
| if !isEmittableShardIDIter(&isEmittableShardID) { |
| // k isn't a key from the public partitions collection. |
| return nil |
| } |
| |
| var pcolValue []byte |
| for pcolValueIter(&pcolValue) { |
| k, v, err := fn.PairCodec.Decode(kv.Pair{K: k.K, V: pcolValue}) |
| if err != nil { |
| return err |
| } |
| emit(k, v) |
| } |
| |
| return nil |
| } |
| |
| func unwrapPairFn(encoded kv.Pair) ([]byte, []byte) { |
| return encoded.K, encoded.V |
| } |
| |
| func unwrapShardedKeyFn(shardedKey ShardedKey) ([]byte, int) { |
| return shardedKey.K, shardedKey.ShardID |
| } |
| |
| // Filters out KV-s from col that have a key in 'keys'. |
| // |
| // A single key in col may have a huge number of values. This function handles that |
| // case by sharding col randomly into 2048 collections, and then joining the keys |
| // within each of those shards. This reduces stragglers when there is a hot key in col, |
| // since its processing is parallelized 2048 ways. |
| // |
| // Each value in col is randomly selected to be in one of the 2048 shards. Then the |
| // 'keys' collection is joined with the sharded col collection to find the shardIds |
| // that need to be present per key. Then the sharded col collection is joined with |
| // the sharded 'keys' collection to compute the final filtered result. |
| func filterKeysImbalanced(s beam.Scope, col beam.PCollection, keys beam.PCollection) beam.PCollection { |
| s = s.Scope("filterKeysImbalanced") |
| |
| kT, vT := beam.ValidateKVType(col) |
| |
| // Add a random shardId (one of 2048) to each element in col. |
| // PCollection<KV<ShardedKey[key, randInt[0;2048)], []byte>> |
| pcolAsBytesWithShardID := beam.ParDo(s, addRandomShardIDFn, beam.ParDo(s, kv.NewEncodeFn(kT, vT), col)) |
| // Drop values and remove duplicates. |
| // PCollection<ShardedKey[key, randInt[0;2048)]> |
| uniqueKeysWithShardID := filter.Distinct(s, beam.DropValue(s, pcolAsBytesWithShardID)) |
| |
| // Prepare the keys for a CoGroupBy with uniqueSubkeys. |
| // PCollection<KV<key, 0>> |
| keysWithZero := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, keys) |
| // PCollection<kv.Pair<key, 0>> |
| keysEncodedWithZero := beam.ParDo( |
| s, |
| kv.NewEncodeFn(keysWithZero.Type().Components()[0], |
| keysWithZero.Type().Components()[1]), keysWithZero) |
| |
| // Find the shardIds per key in the keys collection. |
| // PCollection<KV<key, *>> |
| groupedByKey := beam.CoGroupByKey(s, beam.ParDo(s, unwrapPairFn, keysEncodedWithZero), beam.ParDo(s, unwrapShardedKeyFn, uniqueKeysWithShardID)) |
| emittableKeysWithShardID := beam.ParDo(s, extractEmittableKeysWithShardIDFn, groupedByKey) |
| |
| // Finally perform the sharded filter. |
| groupedByKeyAndShardID := beam.CoGroupByKey(s, emittableKeysWithShardID, pcolAsBytesWithShardID) |
| return beam.ParDo(s, newFilterKeysWithShardIDFn(kT.Type(), vT.Type()), groupedByKeyAndShardID, beam.TypeDefinition{Var: beam.TType, T: kT.Type()}, beam.TypeDefinition{Var: beam.WType, T: vT.Type()}) |
| } |
| |
| // dropNonPublicPartitionsVFn drops partitions not specified in |
| // PublicPartitions from pcol. It can be used for aggregations on V values, |
| // e.g. Count and DistinctPrivacyID. |
| // |
| // We drop values that are not in the publicPartitions PCollection as follows: |
| // 1. Transform publicPartitions from <V> to <V, int64(0)> (0 is a placeholder value) |
| // 2. Swap pcol.col from <PrivacyKey, V> to <V, PrivacyKey> |
| // 3. Do a CoGroupByKey on the output of 1 and 2. |
| // 4. From the output of 3, only output <PrivacyKey, V> if there is an input |
| // from 1 using the mergePublicValues. |
| // |
| // Returns a PCollection<PrivacyKey, Value> only for values present in |
| // publicPartitions. |
| func dropNonPublicPartitionsVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection) beam.PCollection { |
| if *enableShardedPublicPartitions { |
| return beam.SwapKV(s, filterKeysImbalanced(s, beam.SwapKV(s, pcol.col), publicPartitions)) |
| } else { |
| publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions) |
| groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, beam.SwapKV(s, pcol.col)) |
| return beam.ParDo(s, mergePublicValues, groupedByValue) |
| } |
| } |
| |
| // dropNonPublicPartitionsKVFn drops partitions not specified in |
| // PublicPartitions from pcol. It can be used for aggregations on <K,V> pairs, |
| // e.g. SumPerKey and MeanPerKey. |
| // |
| // We drop values that are not in the publicPartitions PCollection as follows: |
| // 1. Transform publicPartitions from <PartitionKey> to <PartitionKey, int64(0)> (0 is a placeholder value) |
| // 2. Transform pcol.col from <PrivacyKey, <PartitionKey, Value>> to <PartitionKey, <PrivacyKey, Value>> |
| // 3. Do a CoGroupByKey on the output of 1 and 2. |
| // 4. From the output of 3, only output <PartitionKey, <PrivacyKey, Value>> if there |
| // is an input from 1 using mergePublicValues. |
| // 5. Transform output of 4 from <PartitionKey, <PrivacyKey, Value>> to <PrivacyKey, <PartitionKey, Value>> |
| // |
| // This works great for smaller partitions, but we run into performance bottlenecks in |
| // steps 3 & 4 in case some partitions have a huge number of user contributions. |
| // |
| // Returns a PCollection<PrivacyKey, <PartitionKey, Value>> only for values present in |
| // publicPartitions. |
| func dropNonPublicPartitionsKVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection, idType typex.FullType) beam.PCollection { |
| if *enableShardedPublicPartitions { |
| encodedIDV := beam.ParDo( |
| s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T}) |
| filteredEncodedIDV := filterKeysImbalanced(s, encodedIDV, publicPartitions) |
| decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T)) |
| return beam.ParDo(s, decodeFn, filteredEncodedIDV, beam.TypeDefinition{Var: beam.UType, T: idType.Type()}) |
| } else { |
| publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions) |
| encodedIDV := beam.ParDo(s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T}) |
| groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, encodedIDV) |
| merged := beam.SwapKV(s, beam.ParDo(s, mergePublicValues, groupedByValue)) |
| decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T)) |
| return beam.ParDo(s, decodeFn, merged, beam.TypeDefinition{Var: beam.UType, T: idType.Type()}) |
| } |
| } |
| |
| // encodeIDVFn takes a PCollection<ID,kv.Pair{K,V}> as input, and returns a |
| // PCollection<K, kv.Pair{ID,V}>; where ID and V have been coded, and K has been |
| // decoded. |
| type encodeIDVFn struct { |
| IDType beam.EncodedType // Type information of the privacy ID |
| idEnc beam.ElementEncoder // Encoder for privacy ID, set during Setup() according to IDType |
| InputPairCodec *kv.Codec // Codec for the input kv.Pair{K,V} |
| } |
| |
| func newEncodeIDVFn(idType typex.FullType, kvCodec *kv.Codec) *encodeIDVFn { |
| return &encodeIDVFn{ |
| IDType: beam.EncodedType{T: idType.Type()}, |
| InputPairCodec: kvCodec, |
| } |
| } |
| |
| func (fn *encodeIDVFn) Setup() error { |
| fn.idEnc = beam.NewElementEncoder(fn.IDType.T) |
| return fn.InputPairCodec.Setup() |
| } |
| |
| func (fn *encodeIDVFn) ProcessElement(id beam.U, pair kv.Pair) (beam.W, kv.Pair, error) { |
| var idBuf bytes.Buffer |
| if err := fn.idEnc.Encode(id, &idBuf); err != nil { |
| return nil, kv.Pair{}, fmt.Errorf("pbeam.encodeIDVFn.ProcessElement: couldn't encode ID %v: %w", id, err) |
| } |
| k, _, err := fn.InputPairCodec.Decode(pair) |
| return k, kv.Pair{idBuf.Bytes(), pair.V}, err |
| } |
| |
| // decodeIDVFn is the reverse operation of encodeIDVFn. It takes a PCollection<K, kv.Pair{ID,V}> |
| // as input, and returns a PCollection<ID, kv.Pair{K,V}>; where K and V has been coded, and ID |
| // has been decoded. |
| type decodeIDVFn struct { |
| KType beam.EncodedType // Type information of the partition key K |
| kEnc beam.ElementEncoder // Encoder for partition key, set during Setup() according to KType |
| InputPairCodec *kv.Codec // Codec for the input kv.Pair{ID,V} |
| } |
| |
| func newDecodeIDVFn(kType beam.EncodedType, idvCodec *kv.Codec) *decodeIDVFn { |
| return &decodeIDVFn{ |
| KType: kType, |
| InputPairCodec: idvCodec, |
| } |
| } |
| |
| func (fn *decodeIDVFn) Setup() error { |
| fn.kEnc = beam.NewElementEncoder(fn.KType.T) |
| return fn.InputPairCodec.Setup() |
| } |
| |
| func (fn *decodeIDVFn) ProcessElement(k beam.W, pair kv.Pair) (beam.U, kv.Pair, error) { |
| var kBuf bytes.Buffer |
| if err := fn.kEnc.Encode(k, &kBuf); err != nil { |
| return nil, kv.Pair{}, fmt.Errorf("pbeam.decodeIDVFn.ProcessElement: couldn't encode K %v: %w", k, err) |
| } |
| id, _, err := fn.InputPairCodec.Decode(pair) |
| return id, kv.Pair{kBuf.Bytes(), pair.V}, err // pair.V is the V in PCollection<K, kv.Pair{ID,V}> |
| } |
| |
| // partitionMapFn makes a map consisting of public partitions. |
| type partitionMapFn struct { |
| PartitionType beam.EncodedType |
| partitionEnc beam.ElementEncoder |
| } |
| |
| func newPartitionMapFn(partitionType beam.EncodedType) *partitionMapFn { |
| return &partitionMapFn{PartitionType: partitionType} |
| } |
| |
| // Setup is our "constructor" |
| func (fn *partitionMapFn) Setup() { |
| fn.partitionEnc = beam.NewElementEncoder(fn.PartitionType.T) |
| } |
| |
| // CreateAccumulator creates a new accumulator for the appropriate data type |
| func (fn *partitionMapFn) CreateAccumulator() pMap { |
| return make(pMap) |
| } |
| |
| // AddInput adds the public partition key to the map |
| func (fn *partitionMapFn) AddInput(p pMap, partitionKey beam.W) (pMap, error) { |
| var partitionBuf bytes.Buffer |
| if err := fn.partitionEnc.Encode(partitionKey, &partitionBuf); err != nil { |
| return p, fmt.Errorf("pbeam.PartitionsMapFn.AddInput: couldn't encode partition key %v: %w", partitionKey, err) |
| } |
| p[base64.StdEncoding.EncodeToString(partitionBuf.Bytes())] = true |
| return p, nil |
| } |
| |
| // MergeAccumulators adds the keys from a to b |
| func (fn *partitionMapFn) MergeAccumulators(a, b pMap) pMap { |
| for k := range a { |
| b[k] = true |
| } |
| return b |
| } |
| |
| type prunePartitionsInMemoryVFn struct { |
| PartitionType beam.EncodedType |
| partitionEnc beam.ElementEncoder |
| PartitionMap pMap |
| } |
| |
| func newPrunePartitionsInMemoryVFn(partitionType beam.EncodedType, partitionMap pMap) *prunePartitionsInMemoryVFn { |
| return &prunePartitionsInMemoryVFn{PartitionType: partitionType, PartitionMap: partitionMap} |
| } |
| |
| func (fn *prunePartitionsInMemoryVFn) Setup() { |
| fn.partitionEnc = beam.NewElementEncoder(fn.PartitionType.T) |
| } |
| |
| func (fn *prunePartitionsInMemoryVFn) ProcessElement(id beam.U, partitionKey beam.W, emit func(beam.U, beam.W)) error { |
| var partitionBuf bytes.Buffer |
| if err := fn.partitionEnc.Encode(partitionKey, &partitionBuf); err != nil { |
| return fmt.Errorf("pbeam.prunePartitionsInMemoryVFn.ProcessElement: couldn't encode partition %v: %w", partitionKey, err) |
| } |
| if fn.PartitionMap[base64.StdEncoding.EncodeToString(partitionBuf.Bytes())] { |
| emit(id, partitionKey) |
| } |
| return nil |
| } |
| |
| type prunePartitionsInMemoryKVFn struct { |
| PartitionMap pMap |
| } |
| |
| func newPrunePartitionsInMemoryKVFn(partitionMap pMap) *prunePartitionsInMemoryKVFn { |
| return &prunePartitionsInMemoryKVFn{PartitionMap: partitionMap} |
| } |
| |
| func (fn *prunePartitionsInMemoryKVFn) ProcessElement(id beam.U, pair kv.Pair, emit func(beam.U, kv.Pair)) { |
| // Partition Key in a kv.Pair is already encoded, we just convert it to base64 encoding. |
| if fn.PartitionMap[base64.StdEncoding.EncodeToString(pair.K)] { |
| emit(id, pair) |
| } |
| } |
| |
| // prunePartitionsFn takes a PCollection<ID, kv.Pair{K,V}> as input, and returns a |
| // PCollection<ID, kv.Pair{K,V}>, where non-public partitions have been dropped. |
| // Used for sum and mean. |
| func prunePartitionsKV(id beam.U, pair kv.Pair, partitionsIter func(*pMap) bool, emit func(beam.U, kv.Pair)) error { |
| var partitionMap pMap |
| partitionsIter(&partitionMap) |
| var err error |
| if partitionMap == nil { |
| return err |
| } |
| // Partition Key in a kv.Pair is already encoded, we just convert it to base64 encoding. |
| if partitionMap[base64.StdEncoding.EncodeToString(pair.K)] { |
| emit(id, pair) |
| } |
| return nil |
| } |
| |
| func mergeResultWithEmptyPublicPartitionsFn(k beam.W, resultIter, publicPartitionsIter func(*beam.V) bool, emit func(beam.W, beam.V)) { |
| var v beam.V |
| if resultIter(&v) { |
| emit(k, v) |
| } else { |
| if publicPartitionsIter(&v) { |
| emit(k, v) |
| } |
| } |
| } |
| |
| // checkPublicPartitions returns an error if publicPartitions parameter of an aggregation |
| // is not valid. |
| func checkPublicPartitions(publicPartitions any, partitionType reflect.Type) error { |
| if publicPartitions != nil { |
| if reflect.TypeOf(publicPartitions) != reflect.TypeOf(beam.PCollection{}) && |
| reflect.ValueOf(publicPartitions).Kind() != reflect.Slice && |
| reflect.ValueOf(publicPartitions).Kind() != reflect.Array { |
| return fmt.Errorf("PublicPartitions=%+v needs to be a beam.PCollection, slice or array", reflect.TypeOf(publicPartitions)) |
| } |
| publicPartitionsCol, isPCollection := publicPartitions.(beam.PCollection) |
| if isPCollection && (!publicPartitionsCol.IsValid() || partitionType != publicPartitionsCol.Type().Type()) { |
| return fmt.Errorf("PublicPartitions=%+v needs to be a valid beam.PCollection with the same type as the partition key (+%v)", publicPartitions, partitionType) |
| } |
| if !isPCollection && reflect.TypeOf(publicPartitions).Elem() != partitionType { |
| return fmt.Errorf("PublicPartitions=%+v needs to be a slice or an array whose elements are the same type as the partition key (%+v)", publicPartitions, partitionType) |
| } |
| } |
| return nil |
| } |