Add Spark to PipelineDP4j, sharded key filtering to PBeam
Export of internal changes.
--
5478b1957ce80be498abac882a75dd958e1b966e by Differential Privacy Team <noreply@google.com>:
Introduce AboveThresholdSelector to StreamingPartitionSelector
PiperOrigin-RevId: 700987846
Change-Id: Ia6b8f8b2d5da936a3d6229c180d59dc176a89970
--
050bc0203c0cb5d07e67bd7e0dab20b348ddd564 by Differential Privacy Team <noreply@google.com>:
n/a
PiperOrigin-RevId: 700310388
Change-Id: Ib0867257ab65acc7bfc37b20aac49211ca3bfec4
GitOrigin-RevId: 11fd989e0ed941bca0f189400d54ff464c4f4238
diff --git a/.gitignore b/.gitignore
index 9d1acbb..d883a6d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,12 +2,3 @@
**/bazel-java
**/bazel-out
**/bazel-testlogs
-
-**/bazel-differential-privacy
-**/.ijwb/
-**/pipelinedp4j/.ijwb/
-**/pipelinedp4j/bazel-pipelinedp4j
-**/pipelinedp4j/MODULE**
-**/examples/.idea/
-**/examples/pipelinedp4j/bazel-pipelinedp4j
-**/examples/pipelinedp4j/MODULE**
diff --git a/cc/algorithms/BUILD b/cc/algorithms/BUILD
index 2ab7d58..764e780 100644
--- a/cc/algorithms/BUILD
+++ b/cc/algorithms/BUILD
@@ -139,7 +139,6 @@
deps = [
":algorithm",
":approx-bounds",
- ":bounded-algorithm",
":numerical-mechanisms",
":util",
"//proto:util-lib",
@@ -245,7 +244,6 @@
deps = [
":algorithm",
":approx-bounds",
- ":bounded-algorithm",
":numerical-mechanisms",
":util",
"//proto:util-lib",
@@ -291,10 +289,8 @@
deps = [
":algorithm",
":approx-bounds",
- ":bounded-algorithm",
":bounded-variance",
":numerical-mechanisms",
- ":util",
"//proto:util-lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
@@ -534,6 +530,7 @@
cc_library(
name = "bounded-algorithm",
hdrs = ["bounded-algorithm.h"],
+ visibility = ["//visibility:private"],
deps = [
":algorithm",
":approx-bounds",
@@ -668,10 +665,14 @@
visibility = ["//visibility:public"],
deps = [
":algorithm",
- ":bounded-algorithm",
+ ":numerical-mechanisms",
":quantile-tree",
+ ":util",
+ "@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_cc_differential_privacy//base:status_macros",
],
)
diff --git a/cc/algorithms/bounded-standard-deviation.h b/cc/algorithms/bounded-standard-deviation.h
index 31676dd..96fad76 100644
--- a/cc/algorithms/bounded-standard-deviation.h
+++ b/cc/algorithms/bounded-standard-deviation.h
@@ -28,10 +28,8 @@
#include "absl/status/statusor.h"
#include "algorithms/algorithm.h"
#include "algorithms/approx-bounds.h"
-#include "algorithms/bounded-algorithm.h"
#include "algorithms/bounded-variance.h"
#include "algorithms/numerical-mechanisms.h"
-#include "algorithms/util.h"
#include "proto/util.h"
#include "proto/data.pb.h"
#include "proto/summary.pb.h"
diff --git a/cc/algorithms/bounded-sum.h b/cc/algorithms/bounded-sum.h
index 784ad2c..5de00b2 100644
--- a/cc/algorithms/bounded-sum.h
+++ b/cc/algorithms/bounded-sum.h
@@ -31,13 +31,11 @@
#include "google/protobuf/any.pb.h"
#include "absl/log/log.h"
-#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "algorithms/algorithm.h"
#include "algorithms/approx-bounds.h"
-#include "algorithms/bounded-algorithm.h"
#include "algorithms/numerical-mechanisms.h"
#include "algorithms/util.h"
#include "proto/util.h"
diff --git a/cc/algorithms/bounded-variance.h b/cc/algorithms/bounded-variance.h
index ad4080e..b298977 100644
--- a/cc/algorithms/bounded-variance.h
+++ b/cc/algorithms/bounded-variance.h
@@ -37,7 +37,6 @@
#include "absl/strings/str_cat.h"
#include "algorithms/algorithm.h"
#include "algorithms/approx-bounds.h"
-#include "algorithms/bounded-algorithm.h"
#include "algorithms/numerical-mechanisms.h"
#include "algorithms/util.h"
#include "proto/util.h"
diff --git a/cc/algorithms/quantiles.h b/cc/algorithms/quantiles.h
index 119bb73..fbc2bf2 100644
--- a/cc/algorithms/quantiles.h
+++ b/cc/algorithms/quantiles.h
@@ -18,13 +18,21 @@
#define DIFFERENTIAL_PRIVACY_CPP_ALGORITHMS_QUANTILES_H_
#include <cstdint>
+#include <memory>
#include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
#include "algorithms/algorithm.h"
-#include "algorithms/bounded-algorithm.h"
+#include "algorithms/numerical-mechanisms.h"
#include "algorithms/quantile-tree.h"
+#include "algorithms/util.h"
+#include "base/status_macros.h"
namespace differential_privacy {
@@ -239,7 +247,7 @@
int max_partitions_contributed_ = 1;
int max_contributions_per_partition_ = 1;
std::unique_ptr<NumericalMechanismBuilder> mechanism_builder_ =
- absl::make_unique<LaplaceMechanism::Builder>();
+ std::make_unique<LaplaceMechanism::Builder>();
std::vector<double> quantiles_;
static absl::Status ValidateQuantiles(std::vector<double>& quantiles) {
diff --git a/cc/algorithms/util.cc b/cc/algorithms/util.cc
index aadcdff..f63b0b9 100644
--- a/cc/algorithms/util.cc
+++ b/cc/algorithms/util.cc
@@ -303,6 +303,10 @@
"contributed to (i.e., L0 sensitivity)");
}
+absl::Status ValidateMaxWindows(std::optional<int> max_windows) {
+ return ValidateIsPositive(max_windows, "Maximum number of windows");
+}
+
absl::Status ValidateMaxContributionsPerPartition(
std::optional<double> max_contributions_per_partition) {
return ValidateIsPositive(max_contributions_per_partition,
diff --git a/cc/algorithms/util.h b/cc/algorithms/util.h
index 39001ca..389016f 100644
--- a/cc/algorithms/util.h
+++ b/cc/algorithms/util.h
@@ -522,6 +522,7 @@
absl::Status ValidateDelta(std::optional<double> delta);
absl::Status ValidateMaxPartitionsContributed(
std::optional<double> max_partitions_contributed);
+absl::Status ValidateMaxWindows(std::optional<int> max_windows);
absl::Status ValidateMaxContributionsPerPartition(
std::optional<double> max_contributions_per_partition);
absl::Status ValidateMaxContributions(std::optional<int> max_contributions);
diff --git a/examples/pipelinedp4j/README.md b/examples/pipelinedp4j/README.md
index b95245f..df1f718 100644
--- a/examples/pipelinedp4j/README.md
+++ b/examples/pipelinedp4j/README.md
@@ -90,7 +90,7 @@
1. Run the program:
```shell
- bazel-bin/BeamExample --inputFilePath=netflix_data.csv --outputFilePath=output.txt
+ bazel-bin/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample --inputFilePath=netflix_data.csv --outputFilePath=output.txt
```
1. View the results:
diff --git a/examples/pipelinedp4j/pom.xml b/examples/pipelinedp4j/pom.xml
index 4aaed76..af5066c 100644
--- a/examples/pipelinedp4j/pom.xml
+++ b/examples/pipelinedp4j/pom.xml
@@ -46,6 +46,10 @@
<configuration>
<source>11</source>
<target>11</target>
+ <excludes>
+ <!-- SparkExample is not supported yet. -->
+ <exclude>com/google/privacy/differentialprivacy/pipelinedp4j/examples/SparkExample.java</exclude>
+ </excludes>
</configuration>
</plugin>
</plugins>
diff --git a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel
index a731aae..26f721e 100644
--- a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel
+++ b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel
@@ -37,23 +37,22 @@
java_binary(
name = "SparkExample",
srcs = [
- "SparkExample.java",
"MovieMetrics.java",
"MovieView.java",
+ "SparkExample.java",
],
main_class = "com.google.privacy.differentialprivacy.pipelinedp4j.examples.SparkExample",
deps = [
"@com_google_privacy_differentialprivacy_pipielinedp4j//main/com/google/privacy/differentialprivacy/pipelinedp4j/api",
- "@maven//:com_google_guava_guava",
- "@maven//:info_picocli_picocli",
- "@maven//:org_jetbrains_kotlin_kotlin_stdlib",
-
- "@maven//:org_apache_spark_spark_core_2_13",
- "@maven//:org_apache_spark_spark_sql_2_13",
- "@maven//:org_apache_spark_spark_mllib_2_13",
- "@maven//:org_apache_spark_spark_catalyst_2_13",
"@maven//:com_fasterxml_jackson_core_jackson_databind",
"@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+ "@maven//:com_google_guava_guava",
+ "@maven//:info_picocli_picocli",
+ "@maven//:org_apache_spark_spark_catalyst_2_13",
+ "@maven//:org_apache_spark_spark_core_2_13",
+ "@maven//:org_apache_spark_spark_mllib_2_13",
+ "@maven//:org_apache_spark_spark_sql_2_13",
+ "@maven//:org_jetbrains_kotlin_kotlin_stdlib",
"@maven//:org_scala_lang_scala_library",
],
)
diff --git a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java
index 5c3aa42..6909f9f 100644
--- a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java
+++ b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java
@@ -185,7 +185,7 @@
/**
* Movie ids (which are group keys for this dataset) are integers from 1 to ~17000. Set public
- * groups 4500-4509.
+ * groups to a subset of them.
*/
private static PCollection<String> publiclyKnownMovieIds(Pipeline pipeline) {
var publicGroupsAsJavaList =
diff --git a/pipelinedp4j/BUILD.bazel b/pipelinedp4j/BUILD.bazel
index b6661f4..8872ad3 100644
--- a/pipelinedp4j/BUILD.bazel
+++ b/pipelinedp4j/BUILD.bazel
@@ -63,7 +63,6 @@
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_collections",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_dp_engine_factory",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_encoders",
-
],
template_file = "pom.template",
)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
index 4ae979f..e177b32 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
@@ -41,6 +41,5 @@
"@maven//:com_google_guava_guava",
"@maven//:org_apache_beam_beam_sdks_java_core",
"@maven//:org_apache_beam_beam_sdks_java_extensions_avro",
-
],
)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt
index b52e2e1..0aa60df 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt
@@ -16,16 +16,16 @@
package com.google.privacy.differentialprivacy.pipelinedp4j.api
+import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkCollection
+import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkEncoderFactory
+import org.apache.spark.sql.Dataset
import com.google.privacy.differentialprivacy.pipelinedp4j.beam.BeamCollection
import com.google.privacy.differentialprivacy.pipelinedp4j.beam.BeamEncoderFactory
import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollection
import com.google.privacy.differentialprivacy.pipelinedp4j.local.LocalCollection
import com.google.privacy.differentialprivacy.pipelinedp4j.local.LocalEncoderFactory
-import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkCollection
-import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkEncoderFactory
import org.apache.beam.sdk.values.PCollection as BeamPCollection
-import org.apache.spark.sql.Dataset
/**
* An internal interface to represent an arbitrary collection that is supported by PipelineDP4j.
@@ -57,7 +57,8 @@
}
/** Spark Collection represented as a Spark Dataset. */
-internal data class SparkPipelineDpCollection<T>(val data: Dataset<T>) : PipelineDpCollection<T> {
+internal data class SparkPipelineDpCollection<T>(val data: Dataset<T>) : PipelineDpCollection<T>
+{
override val encoderFactory = SparkEncoderFactory()
override fun toFrameworkCollection() = SparkCollection<T>(data)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt
index 15de370..80df376 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt
@@ -17,6 +17,7 @@
package com.google.privacy.differentialprivacy.pipelinedp4j.api
import org.apache.beam.sdk.values.PCollection as BeamPCollection
+
import org.apache.spark.sql.Dataset
/**
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt
index ff6559f..a6e05cf 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt
@@ -21,13 +21,13 @@
import com.google.protobuf.Message
import java.io.InputStream
import java.io.OutputStream
-import kotlin.reflect.KClass
import org.apache.beam.sdk.coders.Coder
import org.apache.beam.sdk.coders.CustomCoder
import org.apache.beam.sdk.coders.DoubleCoder
import org.apache.beam.sdk.coders.StringUtf8Coder
import org.apache.beam.sdk.coders.VarIntCoder
import org.apache.beam.sdk.extensions.avro.coders.AvroCoder
+
import org.apache.beam.sdk.extensions.protobuf.ProtoCoder
class BeamEncoder<T>(val coder: Coder<T>) : Encoder<T>
@@ -39,11 +39,10 @@
override fun ints() = BeamEncoder<Int>(VarIntCoder.of())
- override fun <T : Any> records(recordClass: KClass<T>) =
- BeamEncoder<T>(AvroCoder.of(recordClass.java))
+ override fun <T : Any> records(recordClass: Class<T>) = BeamEncoder<T>(AvroCoder.of(recordClass))
- override fun <T : Message> protos(protoClass: KClass<T>) =
- BeamEncoder<T>(ProtoCoder.of(protoClass.java))
+ override fun <T : Message> protos(protoClass: Class<T>) =
+ BeamEncoder<T>(ProtoCoder.of(protoClass))
override fun <T1 : Any, T2 : Any> tuple2sOf(first: Encoder<T1>, second: Encoder<T2>) =
BeamEncoder(
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt
index d37234a..84ebd90 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt
@@ -40,16 +40,56 @@
/** Returns an [Encoder] for an integer value, which can be stored in a [FrameworkCollection]. */
fun ints(): Encoder<Int>
- /** Encoder for data classes. */
- fun <T : Any> records(recordClass: KClass<T>): Encoder<T>
+ /** Encoder for classes. */
+ fun <T : Any> records(recordClass: Class<T>): Encoder<T>
+
+ /** Same as [records(Class)] but accepts Kotlin class. */
+ fun <T : Any> records(recordClass: KClass<T>) = records(recordClass.java)
/** Returns an [Encoder] for a protobuf value, which can be stored in a [FrameworkCollection]. */
- fun <T : Message> protos(protoClass: KClass<T>): Encoder<T>
+ fun <T : Message> protos(protoClass: Class<T>): Encoder<T>
+
+ /** Same as [protos(Class)] but accepts Kotlin class. */
+ fun <T : Message> protos(protoClass: KClass<T>) = protos(protoClass.java)
/** Returns an [Encoder] for a pair of tuples, which can be stored in a [FrameworkCollection]. */
fun <T1 : Any, T2 : Any> tuple2sOf(first: Encoder<T1>, second: Encoder<T2>): Encoder<Pair<T1, T2>>
+
+ /**
+ * Returns the most specific [Encoder] for any record type given its [KClass], including primitive
+ * types and proto but except pairs.
+ *
+ * Use it when the record type is not known at compile time and it can be a primitive type or a
+ * proto. This method will return the most appropriate (and efficient) [Encoder] for the given
+ * type.
+ *
+ * Note that this method does not work for pairs ([tuple2sOf]) and for any other classes that are
+ * parameterized by generic types.
+ */
+ fun <T : Any> recordsOfUnknownClass(recordClass: Class<T>) =
+ when {
+ recordClass == String::class.java -> strings()
+ recordClass == Double::class.java -> doubles()
+ recordClass == Int::class.java -> ints()
+ Message::class.java.isAssignableFrom(recordClass) -> {
+ @Suppress("UNCHECKED_CAST") protos(recordClass as Class<out Message>)
+ }
+ else -> records(recordClass)
+ }
+
+ /** Same as [recordsOfUnknownClass(Class)] but accepts Kotlin class. */
+ fun <T : Any> recordsOfUnknownClass(recordClass: KClass<T>) =
+ recordsOfUnknownClass(recordClass.java)
}
+/**
+ * Inlines the function and the type parameter which allows to use [EncoderFactory.records] without
+ * specifying the class.
+ */
inline fun <reified T : Any> EncoderFactory.records() = this.records(T::class)
+/**
+ * Inlines the function and the type parameter which allows to use [EncoderFactory.protos] without
+ * specifying the class.
+ */
inline fun <reified T : Message> EncoderFactory.protos() = this.protos(T::class)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt
index 76489cd..3c8098d 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt
@@ -19,7 +19,6 @@
import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
import com.google.protobuf.Message
-import kotlin.reflect.KClass
class LocalEncoderFactory() : EncoderFactory {
// The implementation of local encoders is empty because when the data is being processed
@@ -36,9 +35,9 @@
return object : Encoder<Int> {}
}
- override fun <T : Any> records(recordClass: KClass<T>): Encoder<T> = object : Encoder<T> {}
+ override fun <T : Any> records(recordClass: Class<T>): Encoder<T> = object : Encoder<T> {}
- override fun <T : Message> protos(protoClass: KClass<T>): Encoder<T> = object : Encoder<T> {}
+ override fun <T : Message> protos(protoClass: Class<T>): Encoder<T> = object : Encoder<T> {}
override fun <T1 : Any, T2 : Any> tuple2sOf(first: Encoder<T1>, second: Encoder<T2>) =
object : Encoder<Pair<T1, T2>> {}
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
index 8092662..2efe020 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
@@ -11,32 +11,34 @@
srcs = ["SparkEncoders.kt"],
deps = [
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:encoders",
- "@maven//:com_google_protobuf_protobuf_java",
- "@maven//:org_apache_spark_spark_core_2_13",
- "@maven//:org_apache_spark_spark_sql_2_13",
- "@maven//:org_apache_spark_spark_mllib_2_13",
- "@maven//:org_apache_spark_spark_catalyst_2_13",
"@maven//:com_fasterxml_jackson_core_jackson_databind",
"@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+ "@maven//:com_google_protobuf_protobuf_java",
+ "@maven//:org_apache_spark_spark_catalyst_2_13",
+ "@maven//:org_apache_spark_spark_core_2_13",
+ "@maven//:org_apache_spark_spark_mllib_2_13",
+ "@maven//:org_apache_spark_spark_sql_2_13",
"@maven//:org_scala_lang_scala_library",
],
)
kt_jvm_library(
name = "spark_collections",
- srcs = ["SparkCollection.kt",
- "SparkTable.kt",],
+ srcs = [
+ "SparkCollection.kt",
+ "SparkTable.kt",
+ ],
deps = [
":spark_encoders",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:encoders",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:framework_collections",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/local:local_collections",
- "@maven//:org_apache_spark_spark_core_2_13",
- "@maven//:org_apache_spark_spark_sql_2_13",
- "@maven//:org_apache_spark_spark_mllib_2_13",
- "@maven//:org_apache_spark_spark_catalyst_2_13",
- "@maven//:com_fasterxml_jackson_core_jackson_databind",
- "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+ "@maven//:com_fasterxml_jackson_core_jackson_databind",
+ "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+ "@maven//:org_apache_spark_spark_catalyst_2_13",
+ "@maven//:org_apache_spark_spark_core_2_13",
+ "@maven//:org_apache_spark_spark_mllib_2_13",
+ "@maven//:org_apache_spark_spark_sql_2_13",
],
)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt
index 226cb3c..5d95010 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt
@@ -19,7 +19,6 @@
import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
import com.google.protobuf.Message
-import kotlin.reflect.KClass
import org.apache.spark.sql.Encoders
/** A serializer and a deserializer for the data types to convert into Spark internal data types. */
@@ -39,12 +38,12 @@
return SparkEncoder<Int>(Encoders.INT())
}
- override fun <T : Any> records(recordClass: KClass<T>): SparkEncoder<T> {
- return SparkEncoder(Encoders.bean(recordClass.java))
+ override fun <T : Any> records(recordClass: Class<T>): SparkEncoder<T> {
+ return SparkEncoder(Encoders.bean(recordClass))
}
- override fun <T : Message> protos(protoClass: KClass<T>): SparkEncoder<T> {
- return SparkEncoder<T>(Encoders.kryo(protoClass.java))
+ override fun <T : Message> protos(protoClass: Class<T>): SparkEncoder<T> {
+ return SparkEncoder<T>(Encoders.kryo(protoClass))
}
override fun <T1 : Any, T2 : Any> tuple2sOf(
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
index 692eea9..edba452 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
@@ -21,7 +21,7 @@
"BeamQueryBuilderTest.kt",
"BeamQueryTest.kt",
"SparkQueryBuilderTest.kt",
- "SparkQueryTest.kt"
+ "SparkQueryTest.kt",
],
test_class = "com.google.privacy.differentialprivacy.pipelinedp4j.api.ApiTests",
runtime_deps = [
@@ -32,15 +32,15 @@
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:dp_functions_params",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_spec",
"//tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_tests",
+ "@maven//:com_fasterxml_jackson_core_jackson_databind",
+ "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
"@maven//:com_google_truth_truth",
"@maven//:junit_junit",
"@maven//:org_apache_beam_beam_sdks_java_core",
- "@maven//:org_jetbrains_kotlin_kotlin_test",
- "@maven//:org_apache_spark_spark_core_2_13",
- "@maven//:org_apache_spark_spark_sql_2_13",
- "@maven//:org_apache_spark_spark_mllib_2_13",
"@maven//:org_apache_spark_spark_catalyst_2_13",
- "@maven//:com_fasterxml_jackson_core_jackson_databind",
- "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+ "@maven//:org_apache_spark_spark_core_2_13",
+ "@maven//:org_apache_spark_spark_mllib_2_13",
+ "@maven//:org_apache_spark_spark_sql_2_13",
+ "@maven//:org_jetbrains_kotlin_kotlin_test",
],
)
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt
index e3e046b..5902b8e 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package com.google.privacy.differentialprivacy.pipelinedp4j.api
import com.google.common.truth.Truth.assertThat
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt
index 2d5b5c0..bffdb88 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package com.google.privacy.differentialprivacy.pipelinedp4j.api
import com.google.common.truth.Truth.assertThat
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel
index b3734a9..28c0e27 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel
@@ -37,6 +37,8 @@
"@maven//:com_google_truth_truth",
"@maven//:junit_junit",
"@maven//:org_apache_beam_beam_sdks_java_core",
+ "@maven//:org_apache_beam_beam_sdks_java_extensions_avro",
+ "@maven//:org_apache_beam_beam_sdks_java_extensions_protobuf",
"@maven//:org_hamcrest_hamcrest",
],
)
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt
index c76201d..7a1b9b2 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt
@@ -16,6 +16,8 @@
package com.google.privacy.differentialprivacy.pipelinedp4j.beam
+import org.apache.beam.sdk.extensions.protobuf.ProtoCoder
+import com.google.common.truth.Truth.assertThat
import com.google.privacy.differentialprivacy.pipelinedp4j.core.ContributionWithPrivacyId
import com.google.privacy.differentialprivacy.pipelinedp4j.core.contributionWithPrivacyId
import com.google.privacy.differentialprivacy.pipelinedp4j.core.encoderOfContributionWithPrivacyId
@@ -25,6 +27,10 @@
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.quantilesAccumulator
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.sumAccumulator
import com.google.protobuf.ByteString
+import org.apache.beam.sdk.coders.DoubleCoder
+import org.apache.beam.sdk.coders.StringUtf8Coder
+import org.apache.beam.sdk.coders.VarIntCoder
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder
import org.apache.beam.sdk.testing.PAssert
import org.apache.beam.sdk.testing.TestPipeline
import org.apache.beam.sdk.transforms.Create
@@ -75,21 +81,9 @@
@Test
fun records_isPossibleToCreateBeamPCollectionOfThatType() {
- val input =
- listOf(
- contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
- contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
- contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
- contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
- )
+ val input = listOf(TestRecord("privacyId1", 1.0, -1), TestRecord("privacyId2", 2.0, 2))
val inputCoder =
- (encoderOfContributionWithPrivacyId(
- beamEncoderFactory.strings(),
- beamEncoderFactory.strings(),
- beamEncoderFactory,
- )
- as BeamEncoder<ContributionWithPrivacyId<String, String>>)
- .coder
+ (beamEncoderFactory.records(TestRecord::class) as BeamEncoder<TestRecord>).coder
val pCollection = testPipeline.apply(Create.of(input).withCoder(inputCoder))
@@ -115,7 +109,9 @@
},
compoundAccumulator {},
)
- val inputCoder = beamEncoderFactory.protos(CompoundAccumulator::class).coder
+ val inputCoder =
+ (beamEncoderFactory.protos(CompoundAccumulator::class) as BeamEncoder<CompoundAccumulator>)
+ .coder
val pCollection = testPipeline.apply(Create.of(input).withCoder(inputCoder))
@@ -137,6 +133,79 @@
testPipeline.run().waitUntilFinish()
}
+ @Test
+ fun contributionWithPrivacyIdOf_isPossibleToCreateBeamPCollectionOfThatType() {
+ val input =
+ listOf(
+ contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
+ contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
+ contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
+ contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
+ )
+ val inputCoder =
+ (encoderOfContributionWithPrivacyId(
+ beamEncoderFactory.strings(),
+ beamEncoderFactory.strings(),
+ beamEncoderFactory,
+ )
+ as BeamEncoder<ContributionWithPrivacyId<String, String>>)
+ .coder
+
+ val pCollection = testPipeline.apply(Create.of(input).withCoder(inputCoder))
+
+ PAssert.that(pCollection).containsInAnyOrder(input)
+
+ testPipeline.run().waitUntilFinish()
+ }
+
+ @Test
+ fun recordsOfUnknownClass_string_createsEncoderWithStringCoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder = beamEncoderFactory.recordsOfUnknownClass(String::class) as BeamEncoder<String>
+
+ assertThat(encoder.coder).isInstanceOf(StringUtf8Coder::class.java)
+ }
+
+ @Test
+ fun recordsOfUnknownClass_double_createsEncoderWithDoubleCoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder = beamEncoderFactory.recordsOfUnknownClass(Double::class) as BeamEncoder<Double>
+
+ assertThat(encoder.coder).isInstanceOf(DoubleCoder::class.java)
+ }
+
+ @Test
+ fun recordsOfUnknownClass_int_createsEncoderWithIntCoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder = beamEncoderFactory.recordsOfUnknownClass(Int::class) as BeamEncoder<Int>
+
+ assertThat(encoder.coder).isInstanceOf(VarIntCoder::class.java)
+ }
+
+ @Test
+ fun recordsOfUnknownClass_kotlinClass_createsEncoderWithAvroCoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder =
+ beamEncoderFactory.recordsOfUnknownClass(TestRecord::class) as BeamEncoder<TestRecord>
+
+ assertThat(encoder.coder).isInstanceOf(AvroCoder::class.java)
+ }
+
+ @Test
+ fun recordsOfUnknownClass_proto_createsEncoderWithProtoCoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder =
+ beamEncoderFactory.recordsOfUnknownClass(CompoundAccumulator::class)
+ as BeamEncoder<CompoundAccumulator>
+
+ assertThat(encoder.coder).isInstanceOf(ProtoCoder::class.java)
+ }
+
+ private data class TestRecord(val string: String, val double: Double, val int: Int) {
+ // Required for Beam serialization.
+ private constructor() : this("", 0.0, 0)
+ }
+
companion object {
private val beamEncoderFactory = BeamEncoderFactory()
}
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
index 804cd08..c1cef1f 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
@@ -31,15 +31,15 @@
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/proto:accumulators_kt_proto",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_collections",
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_encoders",
+ "@maven//:com_fasterxml_jackson_core_jackson_databind",
+ "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"@maven//:com_google_truth_truth",
"@maven//:junit_junit",
- "@maven//:org_apache_spark_spark_core_2_13",
- "@maven//:org_apache_spark_spark_sql_2_13",
- "@maven//:org_apache_spark_spark_mllib_2_13",
"@maven//:org_apache_spark_spark_catalyst_2_13",
- "@maven//:com_fasterxml_jackson_core_jackson_databind",
- "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+ "@maven//:org_apache_spark_spark_core_2_13",
+ "@maven//:org_apache_spark_spark_mllib_2_13",
+ "@maven//:org_apache_spark_spark_sql_2_13",
],
)
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt
index 420008e..1267e62 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt
@@ -26,6 +26,7 @@
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.quantilesAccumulator
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.sumAccumulator
import com.google.protobuf.ByteString
+import org.apache.spark.sql.Encoders
import org.junit.ClassRule
import org.junit.Test
import org.junit.runner.RunWith
@@ -38,7 +39,9 @@
fun strings_isPossibleToCreateSparkCollectionOfThatType() {
val input = listOf("a", "b", "c")
val inputCoder = sparkEncoderFactory.strings().encoder
+
val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
@@ -46,7 +49,9 @@
fun doubles_isPossibleToCreateSparkCollectionOfThatType() {
val input = listOf(-1.2, 0.0, 2.1)
val inputCoder = sparkEncoderFactory.doubles().encoder
+
val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
@@ -54,28 +59,20 @@
fun ints_isPossibleToCreateSparkCollectionOfThatType() {
val input = listOf(-1, 0, 1)
val inputCoder = sparkEncoderFactory.ints().encoder
+
val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
@Test
fun records_isPossibleToCreateSparkCollectionOfThatType() {
- val input =
- listOf(
- contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
- contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
- contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
- contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
- )
+ val input = listOf(TestRecord("privacyId1", 1.0, -1), TestRecord("privacyId2", 2.0, 2))
val inputCoder =
- (encoderOfContributionWithPrivacyId(
- sparkEncoderFactory.strings(),
- sparkEncoderFactory.strings(),
- sparkEncoderFactory,
- )
- as SparkEncoder<ContributionWithPrivacyId<String, String>>)
- .encoder
+ (sparkEncoderFactory.records(TestRecord::class) as SparkEncoder<TestRecord>).encoder
+
val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
@@ -96,8 +93,12 @@
},
compoundAccumulator {},
)
- val inputCoder = sparkEncoderFactory.protos(CompoundAccumulator::class).encoder
+ val inputCoder =
+ (sparkEncoderFactory.protos(CompoundAccumulator::class) as SparkEncoder<CompoundAccumulator>)
+ .encoder
+
val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
@@ -110,6 +111,7 @@
.encoder
val dataset = sparkSession.spark.createDataset(input, inputEncoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
@@ -129,10 +131,80 @@
sparkEncoderFactory.strings(),
)
.encoder
+
val dataset = sparkSession.spark.createDataset(input, inputEncoder)
+
assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
}
+ @Test
+ fun contributionWithPrivacyIdOf_isPossibleToCreateSparkCollectionOfThatType() {
+ val input =
+ listOf(
+ contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
+ contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
+ contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
+ contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
+ )
+ val inputCoder =
+ (encoderOfContributionWithPrivacyId(
+ sparkEncoderFactory.strings(),
+ sparkEncoderFactory.strings(),
+ sparkEncoderFactory,
+ )
+ as SparkEncoder<ContributionWithPrivacyId<String, String>>)
+ .encoder
+
+ val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
+ assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
+ }
+
+ @Test
+ fun recordsOfUnknownClass_string_createsEncoderWithSparkStringEncoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder = sparkEncoderFactory.recordsOfUnknownClass(String::class) as SparkEncoder<String>
+
+ assertThat(encoder.encoder).isEqualTo(Encoders.STRING())
+ }
+
+ @Test
+ fun recordsOfUnknownClass_double_createsEncoderWithSparkDoubleEncoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder = sparkEncoderFactory.recordsOfUnknownClass(Double::class) as SparkEncoder<Double>
+
+ assertThat(encoder.encoder).isEqualTo(Encoders.DOUBLE())
+ }
+
+ @Test
+ fun recordsOfUnknownClass_int_createsEncoderWithSparkIntEncoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder = sparkEncoderFactory.recordsOfUnknownClass(Int::class) as SparkEncoder<Int>
+
+ assertThat(encoder.encoder).isEqualTo(Encoders.INT())
+ }
+
+ @Test
+ fun recordsOfUnknownClass_kotlinClass_createsEncoderWithSparkBeanEncoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder =
+ sparkEncoderFactory.recordsOfUnknownClass(TestRecord::class) as SparkEncoder<TestRecord>
+
+ assertThat(encoder.encoder).isEqualTo(Encoders.bean(TestRecord::class.java))
+ }
+
+ @Test
+ fun recordsOfUnknownClass_proto_createsEncoderWithSparkKryoEncoder() {
+ @Suppress("UNCHECKED_CAST")
+ val encoder =
+ sparkEncoderFactory.recordsOfUnknownClass(CompoundAccumulator::class)
+ as SparkEncoder<CompoundAccumulator>
+
+ assertThat(encoder.encoder).isEqualTo(Encoders.kryo(CompoundAccumulator::class.java))
+ }
+
+ data class TestRecord(var string: String = "", var double: Double = 0.0, var int: Int = 0)
+
companion object {
@JvmField @ClassRule val sparkSession = SparkSessionRule()
private val sparkEncoderFactory = SparkEncoderFactory()
diff --git a/privacy-on-beam/pbeam/public_partitions.go b/privacy-on-beam/pbeam/public_partitions.go
index f95f6ae..145d0ac 100644
--- a/privacy-on-beam/pbeam/public_partitions.go
+++ b/privacy-on-beam/pbeam/public_partitions.go
@@ -22,12 +22,19 @@
"bytes"
"encoding/base64"
"fmt"
+ "math/rand"
"reflect"
+ "flag"
"github.com/google/differential-privacy/privacy-on-beam/v3/internal/kv"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter"
+)
+
+var (
+ enableShardedPublicPartitions = flag.Bool("enable_sharded_public_partitions", false, "Enable sharded public partitions. This is a temporary flag to allow us to test the new sharded implementation of public partition filtering.")
)
func init() {
@@ -51,6 +58,11 @@
register.Function4x0[beam.W, func(*beam.V) bool, func(*beam.V) bool, func(beam.W, beam.V)](mergeResultWithEmptyPublicPartitionsFn)
register.Iter1[beam.V]()
register.Emitter2[beam.W, beam.V]()
+
+ register.Function1x2[kv.Pair, ShardedKey, []byte](addRandomShardIDFn)
+ register.Function4x0[[]byte, func(*[]byte) bool, func(*int) bool, func(ShardedKey, int)](extractEmittableKeysWithShardIDFn)
+ register.DoFn4x1[ShardedKey, func(*int) bool, func(*[]byte) bool, func(beam.T, beam.W), error](&filterKeysWithShardIDFn{})
+ register.Function1x2[kv.Pair, []byte, []byte](unwrapPairFn)
}
// newAddZeroValuesToPublicPartitionsFn turns a PCollection<V> into PCollection<V,0>.
@@ -138,6 +150,116 @@
}
}
+// ShardedKey is a key encoded as bytes with an int shardID.
+type ShardedKey struct {
+ K []byte
+ ShardID int
+}
+
+func addRandomShardIDFn(encoded kv.Pair) (ShardedKey, []byte) {
+ return ShardedKey{K: encoded.K, ShardID: rand.Intn(2048)}, encoded.V
+}
+
+func extractEmittableKeysWithShardIDFn(k []byte, isAllowedKeyIter func(*[]byte) bool, shardIDIter func(*int) bool, emit func(ShardedKey, int)) {
+ var isAllowedKey []byte
+ if !isAllowedKeyIter(&isAllowedKey) {
+		// k is not an allowlisted key, filter it out.
+ return
+ }
+
+ var subkey int
+ for shardIDIter(&subkey) {
+ emit(ShardedKey{K: k, ShardID: subkey}, 0)
+ }
+}
+
+type filterKeysWithShardIDFn struct {
+ KType beam.EncodedType
+ VType beam.EncodedType
+ PairCodec *kv.Codec
+}
+
+func newFilterKeysWithShardIDFn(kType, vType reflect.Type) *filterKeysWithShardIDFn {
+ return &filterKeysWithShardIDFn{
+ KType: beam.EncodedType{T: kType},
+ VType: beam.EncodedType{T: vType},
+ }
+}
+
+func (fn *filterKeysWithShardIDFn) Setup() error {
+ fn.PairCodec = kv.NewCodec(fn.KType.T, fn.VType.T)
+ return fn.PairCodec.Setup()
+}
+
+func (fn *filterKeysWithShardIDFn) ProcessElement(k ShardedKey, isEmittableShardIDIter func(*int) bool, pcolValueIter func(*[]byte) bool, emit func(beam.T, beam.W)) error {
+ var isEmittableShardID int
+ if !isEmittableShardIDIter(&isEmittableShardID) {
+ // k isn't a key from the public partitions collection.
+ return nil
+ }
+
+ var pcolValue []byte
+ for pcolValueIter(&pcolValue) {
+ k, v, err := fn.PairCodec.Decode(kv.Pair{K: k.K, V: pcolValue})
+ if err != nil {
+ return err
+ }
+ emit(k, v)
+ }
+
+ return nil
+}
+
+func unwrapPairFn(encoded kv.Pair) ([]byte, []byte) {
+ return encoded.K, encoded.V
+}
+
+func unwrapShardedKeyFn(shardedKey ShardedKey) ([]byte, int) {
+ return shardedKey.K, shardedKey.ShardID
+}
+
+// Filters col, keeping only the KV-s whose key is present in 'keys'.
+//
+// A single key in col may have a huge number of values. This function handles that
+// case by sharding col randomly into 2048 collections, and then joining the keys
+// within each of those shards. This reduces stragglers when there is a hot key in col,
+// since its processing is parallelized 2048 ways.
+//
+// Each value in col is randomly selected to be in one of the 2048 shards. Then the
+// 'keys' collection is joined with the sharded col collection to find the shardIds
+// that need to be present per key. Then the sharded col collection is joined with
+// the sharded 'keys' collection to compute the final filtered result.
+func filterKeysImbalanced(s beam.Scope, col beam.PCollection, keys beam.PCollection) beam.PCollection {
+ s = s.Scope("filterKeysImbalanced")
+
+ kT, vT := beam.ValidateKVType(col)
+
+ // Add a random shardId (one of 2048) to each element in col.
+ // PCollection<KV<ShardedKey[key, randInt[0;2048)], []byte>>
+ pcolAsBytesWithShardID := beam.ParDo(s, addRandomShardIDFn, beam.ParDo(s, kv.NewEncodeFn(kT, vT), col))
+ // Drop values and remove duplicates.
+ // PCollection<ShardedKey[key, randInt[0;2048)]>
+ uniqueKeysWithShardID := filter.Distinct(s, beam.DropValue(s, pcolAsBytesWithShardID))
+
+	// Prepare the keys for a CoGroupBy with uniqueKeysWithShardID.
+ // PCollection<KV<key, 0>>
+ keysWithZero := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, keys)
+ // PCollection<kv.Pair<key, 0>>
+ keysEncodedWithZero := beam.ParDo(
+ s,
+ kv.NewEncodeFn(keysWithZero.Type().Components()[0],
+ keysWithZero.Type().Components()[1]), keysWithZero)
+
+ // Find the shardIds per key in the keys collection.
+ // PCollection<KV<key, *>>
+ groupedByKey := beam.CoGroupByKey(s, beam.ParDo(s, unwrapPairFn, keysEncodedWithZero), beam.ParDo(s, unwrapShardedKeyFn, uniqueKeysWithShardID))
+ emittableKeysWithShardID := beam.ParDo(s, extractEmittableKeysWithShardIDFn, groupedByKey)
+
+ // Finally perform the sharded filter.
+ groupedByKeyAndShardID := beam.CoGroupByKey(s, emittableKeysWithShardID, pcolAsBytesWithShardID)
+ return beam.ParDo(s, newFilterKeysWithShardIDFn(kT.Type(), vT.Type()), groupedByKeyAndShardID, beam.TypeDefinition{Var: beam.TType, T: kT.Type()}, beam.TypeDefinition{Var: beam.WType, T: vT.Type()})
+}
+
// dropNonPublicPartitionsVFn drops partitions not specified in
// PublicPartitions from pcol. It can be used for aggregations on V values,
// e.g. Count and DistinctPrivacyID.
@@ -152,9 +274,13 @@
// Returns a PCollection<PrivacyKey, Value> only for values present in
// publicPartitions.
func dropNonPublicPartitionsVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection) beam.PCollection {
- publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
- groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, beam.SwapKV(s, pcol.col))
- return beam.ParDo(s, mergePublicValues, groupedByValue)
+ if *enableShardedPublicPartitions {
+ return beam.SwapKV(s, filterKeysImbalanced(s, beam.SwapKV(s, pcol.col), publicPartitions))
+ } else {
+ publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
+ groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, beam.SwapKV(s, pcol.col))
+ return beam.ParDo(s, mergePublicValues, groupedByValue)
+ }
}
// dropNonPublicPartitionsKVFn drops partitions not specified in
@@ -175,12 +301,20 @@
// Returns a PCollection<PrivacyKey, <PartitionKey, Value>> only for values present in
// publicPartitions.
func dropNonPublicPartitionsKVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection, idType typex.FullType) beam.PCollection {
- publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
- encodedIDV := beam.ParDo(s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T})
- groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, encodedIDV)
- merged := beam.SwapKV(s, beam.ParDo(s, mergePublicValues, groupedByValue))
- decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T))
- return beam.ParDo(s, decodeFn, merged, beam.TypeDefinition{Var: beam.UType, T: idType.Type()})
+ if *enableShardedPublicPartitions {
+ encodedIDV := beam.ParDo(
+ s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T})
+ filteredEncodedIDV := filterKeysImbalanced(s, encodedIDV, publicPartitions)
+ decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T))
+ return beam.ParDo(s, decodeFn, filteredEncodedIDV, beam.TypeDefinition{Var: beam.UType, T: idType.Type()})
+ } else {
+ publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
+ encodedIDV := beam.ParDo(s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T})
+ groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, encodedIDV)
+ merged := beam.SwapKV(s, beam.ParDo(s, mergePublicValues, groupedByValue))
+ decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T))
+ return beam.ParDo(s, decodeFn, merged, beam.TypeDefinition{Var: beam.UType, T: idType.Type()})
+ }
}
// encodeIDVFn takes a PCollection<ID,kv.Pair{K,V}> as input, and returns a
diff --git a/privacy-on-beam/pbeam/public_partitions_test.go b/privacy-on-beam/pbeam/public_partitions_test.go
index 54e83e1..1ed8a00 100644
--- a/privacy-on-beam/pbeam/public_partitions_test.go
+++ b/privacy-on-beam/pbeam/public_partitions_test.go
@@ -22,6 +22,7 @@
"reflect"
"testing"
+ "flag"
"github.com/google/differential-privacy/privacy-on-beam/v3/pbeam/testutils"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
@@ -104,6 +105,40 @@
}
}
+// TODO: Remove once the enable_sharded_public_partitions flag is gone.
+func TestDropNonPublicPartitionsVFnShardedImpl(t *testing.T) {
+ flag.Set("enable_sharded_public_partitions", "true")
+
+ pairs := testutils.ConcatenatePairs(
+ testutils.MakePairsWithFixedV(7, 0),
+ testutils.MakePairsWithFixedVStartingFromKey(7, 10, 1),
+ testutils.MakePairsWithFixedVStartingFromKey(17, 83, 2),
+ testutils.MakePairsWithFixedVStartingFromKey(100, 10, 3),
+ )
+
+ // Keep partitions 0, 2;
+ // drop partitions 1, 3.
+ result := testutils.ConcatenatePairs(
+ testutils.MakePairsWithFixedV(7, 0),
+ testutils.MakePairsWithFixedVStartingFromKey(17, 83, 2),
+ )
+
+ p, s, col, want := ptest.CreateList2(pairs, result)
+ want = beam.ParDo(s, testutils.PairToKV, want)
+ col = beam.ParDo(s, testutils.PairToKV, col)
+ partitions := []int{0, 2}
+
+ partitionsCol := beam.CreateList(s, partitions)
+ epsilon := 50.0
+ pcol := MakePrivate(s, col, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+ got := dropNonPublicPartitionsVFn(s, partitionsCol, pcol)
+ testutils.EqualsKVInt(t, s, got, want)
+ if err := ptest.Run(p); err != nil {
+ t.Errorf("DropNonPublicPartitionsVFn did not drop non public partitions as expected: %v", err)
+ }
+ flag.Set("enable_sharded_public_partitions", "false")
+}
+
// TestDropNonPublicPartitionsKVFn checks that int elements with non-public partitions
// are dropped (tests function used for sum and mean).
func TestDropNonPublicPartitionsKVFn(t *testing.T) {
@@ -147,6 +182,51 @@
}
}
+// TODO: Remove once the enable_sharded_public_partitions flag is gone.
+func TestDropNonPublicPartitionsKVFnShardedImpl(t *testing.T) {
+ flag.Set("enable_sharded_public_partitions", "true")
+
+ triples := testutils.ConcatenateTriplesWithIntValue(
+ testutils.MakeTripleWithIntValueStartingFromKey(0, 7, 0, 0),
+ testutils.MakeTripleWithIntValueStartingFromKey(7, 3, 1, 0),
+ testutils.MakeTripleWithIntValueStartingFromKey(10, 90, 2, 0),
+ testutils.MakeTripleWithIntValueStartingFromKey(100, 100, 11, 0),
+ testutils.MakeTripleWithIntValueStartingFromKey(200, 5, 12, 0))
+ // Keep partitions 0, 2.
+	// Drop partitions 1, 11, 12.
+ result := testutils.ConcatenateTriplesWithIntValue(
+ testutils.MakeTripleWithIntValueStartingFromKey(0, 7, 0, 0),
+ testutils.MakeTripleWithIntValueStartingFromKey(10, 90, 2, 0))
+
+ p, s, col, col2 := ptest.CreateList2(triples, result)
+ // Doesn't matter that the values 3, 4, 5, 6, 9, 10
+ // are in the partitions PCollection because we are
+ // just dropping the values that are in our original PCollection
+ // that are not in public partitions.
+ partitionsCol := beam.CreateList(s, []int{0, 2, 3, 4, 5, 6, 9, 10})
+ col = beam.ParDo(s, testutils.ExtractIDFromTripleWithIntValue, col)
+ col2 = beam.ParDo(s, testutils.ExtractIDFromTripleWithIntValue, col2)
+ epsilon := 50.0
+
+ pcol := MakePrivate(s, col, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+ pcol = ParDo(s, testutils.TripleWithIntValueToKV, pcol)
+ idT, _ := beam.ValidateKVType(pcol.col)
+
+ got := dropNonPublicPartitionsKVFn(s, partitionsCol, pcol, idT)
+ got = beam.SwapKV(s, got)
+
+ pcol2 := MakePrivate(s, col2, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+ pcol2 = ParDo(s, testutils.TripleWithIntValueToKV, pcol2)
+ want := pcol2.col
+ want = beam.SwapKV(s, want)
+
+ testutils.EqualsKVInt(t, s, got, want)
+ if err := ptest.Run(p); err != nil {
+ t.Errorf("TestDropNonPublicPartitionsKVFn did not drop non public partitions as expected: %v", err)
+ }
+ flag.Set("enable_sharded_public_partitions", "false")
+}
+
// Check that float elements with non-public partitions
// are dropped (tests function used for sum and mean).
func TestDropNonPublicPartitionsFloat(t *testing.T) {
@@ -190,3 +270,49 @@
t.Errorf("TestDropNonPublicPartitionsFloat did not drop non public partitions as expected: %v", err)
}
}
+
+// TODO: Remove once the enable_sharded_public_partitions flag is gone.
+func TestDropNonPublicPartitionsFloatShardedImpl(t *testing.T) {
+ flag.Set("enable_sharded_public_partitions", "true")
+
+ // In this test, we check that non-public partitions
+ // are dropped. This function is used for sum and mean.
+ // Used example values from the mean test.
+ triples := testutils.ConcatenateTriplesWithFloatValue(
+ testutils.MakeTripleWithFloatValue(7, 0, 2.0),
+ testutils.MakeTripleWithFloatValueStartingFromKey(7, 100, 1, 1.3),
+ testutils.MakeTripleWithFloatValueStartingFromKey(107, 150, 1, 2.5),
+ )
+ // Keep partition 0.
+ // drop partition 1.
+ result := testutils.ConcatenateTriplesWithFloatValue(
+ testutils.MakeTripleWithFloatValue(7, 0, 2.0))
+
+ p, s, col, col2 := ptest.CreateList2(triples, result)
+
+ // Doesn't matter that the values 2, 3, 4, 5, 6, 7 are in the partitions PCollection.
+ // We are just dropping the values that are in our original PCollection that are not in
+ // public partitions.
+ partitionsCol := beam.CreateList(s, []int{0, 2, 3, 4, 5, 6, 7})
+ col = beam.ParDo(s, testutils.ExtractIDFromTripleWithFloatValue, col)
+ col2 = beam.ParDo(s, testutils.ExtractIDFromTripleWithFloatValue, col2)
+ epsilon := 50.0
+
+ pcol := MakePrivate(s, col, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+ pcol = ParDo(s, testutils.TripleWithFloatValueToKV, pcol)
+ idT, _ := beam.ValidateKVType(pcol.col)
+
+ got := dropNonPublicPartitionsKVFn(s, partitionsCol, pcol, idT)
+ got = beam.SwapKV(s, got)
+
+ pcol2 := MakePrivate(s, col2, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+ pcol2 = ParDo(s, testutils.TripleWithFloatValueToKV, pcol2)
+ want := pcol2.col
+ want = beam.SwapKV(s, want)
+
+ testutils.EqualsKVInt(t, s, got, want)
+ if err := ptest.Run(p); err != nil {
+ t.Errorf("TestDropNonPublicPartitionsFloat did not drop non public partitions as expected: %v", err)
+ }
+ flag.Set("enable_sharded_public_partitions", "false")
+}
diff --git a/privacy-on-beam/pbeam/testutils/testutils.go b/privacy-on-beam/pbeam/testutils/testutils.go
index 04851e6..8151f2d 100644
--- a/privacy-on-beam/pbeam/testutils/testutils.go
+++ b/privacy-on-beam/pbeam/testutils/testutils.go
@@ -24,6 +24,7 @@
"math"
"math/big"
"reflect"
+ "sort"
"testing"
"github.com/google/differential-privacy/go/v3/dpagg"
@@ -510,6 +511,7 @@
for vIter(&v) {
vSlice = append(vSlice, float64(v))
}
+ sort.Float64s(vSlice)
return vSlice
}