Add Spark to PipelineDP4j, sharded key filtering to PBeam

Export of internal changes.
--
5478b1957ce80be498abac882a75dd958e1b966e by Differential Privacy Team <noreply@google.com>:

Introduce AboveThresholdSelector to StreamingPartitionSelector

PiperOrigin-RevId: 700987846
Change-Id: Ia6b8f8b2d5da936a3d6229c180d59dc176a89970

--
050bc0203c0cb5d07e67bd7e0dab20b348ddd564 by Differential Privacy Team <noreply@google.com>:

n/a

PiperOrigin-RevId: 700310388
Change-Id: Ib0867257ab65acc7bfc37b20aac49211ca3bfec4
GitOrigin-RevId: 11fd989e0ed941bca0f189400d54ff464c4f4238
diff --git a/.gitignore b/.gitignore
index 9d1acbb..d883a6d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,12 +2,3 @@
 **/bazel-java
 **/bazel-out
 **/bazel-testlogs
-
-**/bazel-differential-privacy
-**/.ijwb/
-**/pipelinedp4j/.ijwb/
-**/pipelinedp4j/bazel-pipelinedp4j
-**/pipelinedp4j/MODULE**
-**/examples/.idea/
-**/examples/pipelinedp4j/bazel-pipelinedp4j
-**/examples/pipelinedp4j/MODULE**
diff --git a/cc/algorithms/BUILD b/cc/algorithms/BUILD
index 2ab7d58..764e780 100644
--- a/cc/algorithms/BUILD
+++ b/cc/algorithms/BUILD
@@ -139,7 +139,6 @@
     deps = [
         ":algorithm",
         ":approx-bounds",
-        ":bounded-algorithm",
         ":numerical-mechanisms",
         ":util",
         "//proto:util-lib",
@@ -245,7 +244,6 @@
     deps = [
         ":algorithm",
         ":approx-bounds",
-        ":bounded-algorithm",
         ":numerical-mechanisms",
         ":util",
         "//proto:util-lib",
@@ -291,10 +289,8 @@
     deps = [
         ":algorithm",
         ":approx-bounds",
-        ":bounded-algorithm",
         ":bounded-variance",
         ":numerical-mechanisms",
-        ":util",
         "//proto:util-lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
@@ -534,6 +530,7 @@
 cc_library(
     name = "bounded-algorithm",
     hdrs = ["bounded-algorithm.h"],
+    visibility = ["//visibility:private"],
     deps = [
         ":algorithm",
         ":approx-bounds",
@@ -668,10 +665,14 @@
     visibility = ["//visibility:public"],
     deps = [
         ":algorithm",
-        ":bounded-algorithm",
+        ":numerical-mechanisms",
         ":quantile-tree",
+        ":util",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_cc_differential_privacy//base:status_macros",
     ],
 )
 
diff --git a/cc/algorithms/bounded-standard-deviation.h b/cc/algorithms/bounded-standard-deviation.h
index 31676dd..96fad76 100644
--- a/cc/algorithms/bounded-standard-deviation.h
+++ b/cc/algorithms/bounded-standard-deviation.h
@@ -28,10 +28,8 @@
 #include "absl/status/statusor.h"
 #include "algorithms/algorithm.h"
 #include "algorithms/approx-bounds.h"
-#include "algorithms/bounded-algorithm.h"
 #include "algorithms/bounded-variance.h"
 #include "algorithms/numerical-mechanisms.h"
-#include "algorithms/util.h"
 #include "proto/util.h"
 #include "proto/data.pb.h"
 #include "proto/summary.pb.h"
diff --git a/cc/algorithms/bounded-sum.h b/cc/algorithms/bounded-sum.h
index 784ad2c..5de00b2 100644
--- a/cc/algorithms/bounded-sum.h
+++ b/cc/algorithms/bounded-sum.h
@@ -31,13 +31,11 @@
 
 #include "google/protobuf/any.pb.h"
 #include "absl/log/log.h"
-#include "absl/memory/memory.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "algorithms/algorithm.h"
 #include "algorithms/approx-bounds.h"
-#include "algorithms/bounded-algorithm.h"
 #include "algorithms/numerical-mechanisms.h"
 #include "algorithms/util.h"
 #include "proto/util.h"
diff --git a/cc/algorithms/bounded-variance.h b/cc/algorithms/bounded-variance.h
index ad4080e..b298977 100644
--- a/cc/algorithms/bounded-variance.h
+++ b/cc/algorithms/bounded-variance.h
@@ -37,7 +37,6 @@
 #include "absl/strings/str_cat.h"
 #include "algorithms/algorithm.h"
 #include "algorithms/approx-bounds.h"
-#include "algorithms/bounded-algorithm.h"
 #include "algorithms/numerical-mechanisms.h"
 #include "algorithms/util.h"
 #include "proto/util.h"
diff --git a/cc/algorithms/quantiles.h b/cc/algorithms/quantiles.h
index 119bb73..fbc2bf2 100644
--- a/cc/algorithms/quantiles.h
+++ b/cc/algorithms/quantiles.h
@@ -18,13 +18,21 @@
 #define DIFFERENTIAL_PRIVACY_CPP_ALGORITHMS_QUANTILES_H_
 
 #include <cstdint>
+#include <memory>
 #include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
 
+#include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
 #include "algorithms/algorithm.h"
-#include "algorithms/bounded-algorithm.h"
+#include "algorithms/numerical-mechanisms.h"
 #include "algorithms/quantile-tree.h"
+#include "algorithms/util.h"
+#include "base/status_macros.h"
 
 namespace differential_privacy {
 
@@ -239,7 +247,7 @@
   int max_partitions_contributed_ = 1;
   int max_contributions_per_partition_ = 1;
   std::unique_ptr<NumericalMechanismBuilder> mechanism_builder_ =
-      absl::make_unique<LaplaceMechanism::Builder>();
+      std::make_unique<LaplaceMechanism::Builder>();
   std::vector<double> quantiles_;
 
   static absl::Status ValidateQuantiles(std::vector<double>& quantiles) {
diff --git a/cc/algorithms/util.cc b/cc/algorithms/util.cc
index aadcdff..f63b0b9 100644
--- a/cc/algorithms/util.cc
+++ b/cc/algorithms/util.cc
@@ -303,6 +303,10 @@
                             "contributed to (i.e., L0 sensitivity)");
 }
 
+absl::Status ValidateMaxWindows(std::optional<int> max_windows) {
+  return ValidateIsPositive(max_windows, "Maximum number of windows");
+}
+
 absl::Status ValidateMaxContributionsPerPartition(
     std::optional<double> max_contributions_per_partition) {
   return ValidateIsPositive(max_contributions_per_partition,
diff --git a/cc/algorithms/util.h b/cc/algorithms/util.h
index 39001ca..389016f 100644
--- a/cc/algorithms/util.h
+++ b/cc/algorithms/util.h
@@ -522,6 +522,7 @@
 absl::Status ValidateDelta(std::optional<double> delta);
 absl::Status ValidateMaxPartitionsContributed(
     std::optional<double> max_partitions_contributed);
+absl::Status ValidateMaxWindows(std::optional<int> max_windows);
 absl::Status ValidateMaxContributionsPerPartition(
     std::optional<double> max_contributions_per_partition);
 absl::Status ValidateMaxContributions(std::optional<int> max_contributions);
diff --git a/examples/pipelinedp4j/README.md b/examples/pipelinedp4j/README.md
index b95245f..df1f718 100644
--- a/examples/pipelinedp4j/README.md
+++ b/examples/pipelinedp4j/README.md
@@ -90,7 +90,7 @@
 1.  Run the program:
 
     ```shell
-    bazel-bin/BeamExample --inputFilePath=netflix_data.csv --outputFilePath=output.txt
+    bazel-bin/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample --inputFilePath=netflix_data.csv --outputFilePath=output.txt
     ```
 
 1.  View the results:
diff --git a/examples/pipelinedp4j/pom.xml b/examples/pipelinedp4j/pom.xml
index 4aaed76..af5066c 100644
--- a/examples/pipelinedp4j/pom.xml
+++ b/examples/pipelinedp4j/pom.xml
@@ -46,6 +46,10 @@
                 <configuration>
                     <source>11</source>
                     <target>11</target>
+                    <excludes>
+                        <!-- SparkExample is not supported yet. -->
+                        <exclude>com/google/privacy/differentialprivacy/pipelinedp4j/examples/SparkExample.java</exclude>
+                    </excludes>
                 </configuration>
             </plugin>
         </plugins>
diff --git a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel
index a731aae..26f721e 100644
--- a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel
+++ b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BUILD.bazel
@@ -37,23 +37,22 @@
 java_binary(
     name = "SparkExample",
     srcs = [
-        "SparkExample.java",
         "MovieMetrics.java",
         "MovieView.java",
+        "SparkExample.java",
     ],
     main_class = "com.google.privacy.differentialprivacy.pipelinedp4j.examples.SparkExample",
     deps = [
         "@com_google_privacy_differentialprivacy_pipielinedp4j//main/com/google/privacy/differentialprivacy/pipelinedp4j/api",
-        "@maven//:com_google_guava_guava",
-        "@maven//:info_picocli_picocli",
-        "@maven//:org_jetbrains_kotlin_kotlin_stdlib",
-
-        "@maven//:org_apache_spark_spark_core_2_13",
-        "@maven//:org_apache_spark_spark_sql_2_13",
-        "@maven//:org_apache_spark_spark_mllib_2_13",
-        "@maven//:org_apache_spark_spark_catalyst_2_13",
         "@maven//:com_fasterxml_jackson_core_jackson_databind",
         "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+        "@maven//:com_google_guava_guava",
+        "@maven//:info_picocli_picocli",
+        "@maven//:org_apache_spark_spark_catalyst_2_13",
+        "@maven//:org_apache_spark_spark_core_2_13",
+        "@maven//:org_apache_spark_spark_mllib_2_13",
+        "@maven//:org_apache_spark_spark_sql_2_13",
+        "@maven//:org_jetbrains_kotlin_kotlin_stdlib",
         "@maven//:org_scala_lang_scala_library",
     ],
 )
diff --git a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java
index 5c3aa42..6909f9f 100644
--- a/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java
+++ b/examples/pipelinedp4j/src/main/java/com/google/privacy/differentialprivacy/pipelinedp4j/examples/BeamExample.java
@@ -185,7 +185,7 @@
 
   /**
    * Movie ids (which are group keys for this dataset) are integers from 1 to ~17000. Set public
-   * groups 4500-4509.
+   * groups to a subset of them.
    */
   private static PCollection<String> publiclyKnownMovieIds(Pipeline pipeline) {
     var publicGroupsAsJavaList =
diff --git a/pipelinedp4j/BUILD.bazel b/pipelinedp4j/BUILD.bazel
index b6661f4..8872ad3 100644
--- a/pipelinedp4j/BUILD.bazel
+++ b/pipelinedp4j/BUILD.bazel
@@ -63,7 +63,6 @@
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_collections",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_dp_engine_factory",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_encoders",
-
     ],
     template_file = "pom.template",
 )
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
index 4ae979f..e177b32 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
@@ -41,6 +41,5 @@
         "@maven//:com_google_guava_guava",
         "@maven//:org_apache_beam_beam_sdks_java_core",
         "@maven//:org_apache_beam_beam_sdks_java_extensions_avro",
-
     ],
 )
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt
index b52e2e1..0aa60df 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/PipelineDpCollection.kt
@@ -16,16 +16,16 @@
 
 package com.google.privacy.differentialprivacy.pipelinedp4j.api
 
+import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkCollection
+import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkEncoderFactory
+import org.apache.spark.sql.Dataset
 import com.google.privacy.differentialprivacy.pipelinedp4j.beam.BeamCollection
 import com.google.privacy.differentialprivacy.pipelinedp4j.beam.BeamEncoderFactory
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollection
 import com.google.privacy.differentialprivacy.pipelinedp4j.local.LocalCollection
 import com.google.privacy.differentialprivacy.pipelinedp4j.local.LocalEncoderFactory
-import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkCollection
-import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkEncoderFactory
 import org.apache.beam.sdk.values.PCollection as BeamPCollection
-import org.apache.spark.sql.Dataset
 
 /**
  * An internal interface to represent an arbitrary collection that is supported by PipelineDP4j.
@@ -57,7 +57,8 @@
 }
 
 /** Spark Collection represented as a Spark Dataset. */
-internal data class SparkPipelineDpCollection<T>(val data: Dataset<T>) : PipelineDpCollection<T> {
+internal data class SparkPipelineDpCollection<T>(val data: Dataset<T>) : PipelineDpCollection<T>
+{
   override val encoderFactory = SparkEncoderFactory()
 
   override fun toFrameworkCollection() = SparkCollection<T>(data)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt
index 15de370..80df376 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/QueryBuilder.kt
@@ -17,6 +17,7 @@
 package com.google.privacy.differentialprivacy.pipelinedp4j.api
 
 import org.apache.beam.sdk.values.PCollection as BeamPCollection
+
 import org.apache.spark.sql.Dataset
 
 /**
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt
index ff6559f..a6e05cf 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncoders.kt
@@ -21,13 +21,13 @@
 import com.google.protobuf.Message
 import java.io.InputStream
 import java.io.OutputStream
-import kotlin.reflect.KClass
 import org.apache.beam.sdk.coders.Coder
 import org.apache.beam.sdk.coders.CustomCoder
 import org.apache.beam.sdk.coders.DoubleCoder
 import org.apache.beam.sdk.coders.StringUtf8Coder
 import org.apache.beam.sdk.coders.VarIntCoder
 import org.apache.beam.sdk.extensions.avro.coders.AvroCoder
+
 import org.apache.beam.sdk.extensions.protobuf.ProtoCoder
 
 class BeamEncoder<T>(val coder: Coder<T>) : Encoder<T>
@@ -39,11 +39,10 @@
 
   override fun ints() = BeamEncoder<Int>(VarIntCoder.of())
 
-  override fun <T : Any> records(recordClass: KClass<T>) =
-    BeamEncoder<T>(AvroCoder.of(recordClass.java))
+  override fun <T : Any> records(recordClass: Class<T>) = BeamEncoder<T>(AvroCoder.of(recordClass))
 
-  override fun <T : Message> protos(protoClass: KClass<T>) =
-    BeamEncoder<T>(ProtoCoder.of(protoClass.java))
+  override fun <T : Message> protos(protoClass: Class<T>) =
+    BeamEncoder<T>(ProtoCoder.of(protoClass))
 
   override fun <T1 : Any, T2 : Any> tuple2sOf(first: Encoder<T1>, second: Encoder<T2>) =
     BeamEncoder(
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt
index d37234a..84ebd90 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/Encoders.kt
@@ -40,16 +40,56 @@
   /** Returns an [Encoder] for an integer value, which can be stored in a [FrameworkCollection]. */
   fun ints(): Encoder<Int>
 
-  /** Encoder for data classes. */
-  fun <T : Any> records(recordClass: KClass<T>): Encoder<T>
+  /** Encoder for classes. */
+  fun <T : Any> records(recordClass: Class<T>): Encoder<T>
+
+  /** Same as [records(Class)] but accepts Kotlin class. */
+  fun <T : Any> records(recordClass: KClass<T>) = records(recordClass.java)
 
   /** Returns an [Encoder] for a protobuf value, which can be stored in a [FrameworkCollection]. */
-  fun <T : Message> protos(protoClass: KClass<T>): Encoder<T>
+  fun <T : Message> protos(protoClass: Class<T>): Encoder<T>
+
+  /** Same as [protos(Class)] but accepts Kotlin class. */
+  fun <T : Message> protos(protoClass: KClass<T>) = protos(protoClass.java)
 
   /** Returns an [Encoder] for a pair of tuples, which can be stored in a [FrameworkCollection]. */
   fun <T1 : Any, T2 : Any> tuple2sOf(first: Encoder<T1>, second: Encoder<T2>): Encoder<Pair<T1, T2>>
+
+  /**
+   * Returns the most specific [Encoder] for any record type given its [KClass], including primitive
+   * types and proto but except pairs.
+   *
+   * Use it when the record type is not known at compile time and it can be a primitive type or a
+   * proto. This method will return the most appropriate (and efficient) [Encoder] for the given
+   * type.
+   *
+   * Note that this method does not work for pairs ([tuple2sOf]) and for any other classes that are
+   * parameterized by generic types.
+   */
+  fun <T : Any> recordsOfUnknownClass(recordClass: Class<T>) =
+    when {
+      recordClass == String::class.java -> strings()
+      recordClass == Double::class.java -> doubles()
+      recordClass == Int::class.java -> ints()
+      Message::class.java.isAssignableFrom(recordClass) -> {
+        @Suppress("UNCHECKED_CAST") protos(recordClass as Class<out Message>)
+      }
+      else -> records(recordClass)
+    }
+
+  /** Same as [recordsOfUnknownClass(Class)] but accepts Kotlin class. */
+  fun <T : Any> recordsOfUnknownClass(recordClass: KClass<T>) =
+    recordsOfUnknownClass(recordClass.java)
 }
 
+/**
+ * Inlines the function and the type parameter which allows to use [EncoderFactory.records] without
+ * specifying the class.
+ */
 inline fun <reified T : Any> EncoderFactory.records() = this.records(T::class)
 
+/**
+ * Inlines the function and the type parameter which allows to use [EncoderFactory.protos] without
+ * specifying the class.
+ */
 inline fun <reified T : Message> EncoderFactory.protos() = this.protos(T::class)
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt
index 76489cd..3c8098d 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/local/LocalEncoderFactory.kt
@@ -19,7 +19,6 @@
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
 import com.google.protobuf.Message
-import kotlin.reflect.KClass
 
 class LocalEncoderFactory() : EncoderFactory {
   // The implementation of local encoders is empty because when the data is being processed
@@ -36,9 +35,9 @@
     return object : Encoder<Int> {}
   }
 
-  override fun <T : Any> records(recordClass: KClass<T>): Encoder<T> = object : Encoder<T> {}
+  override fun <T : Any> records(recordClass: Class<T>): Encoder<T> = object : Encoder<T> {}
 
-  override fun <T : Message> protos(protoClass: KClass<T>): Encoder<T> = object : Encoder<T> {}
+  override fun <T : Message> protos(protoClass: Class<T>): Encoder<T> = object : Encoder<T> {}
 
   override fun <T1 : Any, T2 : Any> tuple2sOf(first: Encoder<T1>, second: Encoder<T2>) =
     object : Encoder<Pair<T1, T2>> {}
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
index 8092662..2efe020 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
@@ -11,32 +11,34 @@
     srcs = ["SparkEncoders.kt"],
     deps = [
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:encoders",
-        "@maven//:com_google_protobuf_protobuf_java",
-        "@maven//:org_apache_spark_spark_core_2_13",
-        "@maven//:org_apache_spark_spark_sql_2_13",
-        "@maven//:org_apache_spark_spark_mllib_2_13",
-        "@maven//:org_apache_spark_spark_catalyst_2_13",
         "@maven//:com_fasterxml_jackson_core_jackson_databind",
         "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+        "@maven//:com_google_protobuf_protobuf_java",
+        "@maven//:org_apache_spark_spark_catalyst_2_13",
+        "@maven//:org_apache_spark_spark_core_2_13",
+        "@maven//:org_apache_spark_spark_mllib_2_13",
+        "@maven//:org_apache_spark_spark_sql_2_13",
         "@maven//:org_scala_lang_scala_library",
     ],
 )
 
 kt_jvm_library(
     name = "spark_collections",
-    srcs = ["SparkCollection.kt",
-            "SparkTable.kt",],
+    srcs = [
+        "SparkCollection.kt",
+        "SparkTable.kt",
+    ],
     deps = [
         ":spark_encoders",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:encoders",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:framework_collections",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/local:local_collections",
-         "@maven//:org_apache_spark_spark_core_2_13",
-         "@maven//:org_apache_spark_spark_sql_2_13",
-         "@maven//:org_apache_spark_spark_mllib_2_13",
-         "@maven//:org_apache_spark_spark_catalyst_2_13",
-         "@maven//:com_fasterxml_jackson_core_jackson_databind",
-         "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+        "@maven//:com_fasterxml_jackson_core_jackson_databind",
+        "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+        "@maven//:org_apache_spark_spark_catalyst_2_13",
+        "@maven//:org_apache_spark_spark_core_2_13",
+        "@maven//:org_apache_spark_spark_mllib_2_13",
+        "@maven//:org_apache_spark_spark_sql_2_13",
     ],
 )
 
diff --git a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt
index 226cb3c..5d95010 100644
--- a/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt
+++ b/pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncoders.kt
@@ -19,7 +19,6 @@
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
 import com.google.protobuf.Message
-import kotlin.reflect.KClass
 import org.apache.spark.sql.Encoders
 
 /** A serializer and a deserializer for the data types to convert into Spark internal data types. */
@@ -39,12 +38,12 @@
     return SparkEncoder<Int>(Encoders.INT())
   }
 
-  override fun <T : Any> records(recordClass: KClass<T>): SparkEncoder<T> {
-    return SparkEncoder(Encoders.bean(recordClass.java))
+  override fun <T : Any> records(recordClass: Class<T>): SparkEncoder<T> {
+    return SparkEncoder(Encoders.bean(recordClass))
   }
 
-  override fun <T : Message> protos(protoClass: KClass<T>): SparkEncoder<T> {
-    return SparkEncoder<T>(Encoders.kryo(protoClass.java))
+  override fun <T : Message> protos(protoClass: Class<T>): SparkEncoder<T> {
+    return SparkEncoder<T>(Encoders.kryo(protoClass))
   }
 
   override fun <T1 : Any, T2 : Any> tuple2sOf(
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
index 692eea9..edba452 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel
@@ -21,7 +21,7 @@
         "BeamQueryBuilderTest.kt",
         "BeamQueryTest.kt",
         "SparkQueryBuilderTest.kt",
-        "SparkQueryTest.kt"
+        "SparkQueryTest.kt",
     ],
     test_class = "com.google.privacy.differentialprivacy.pipelinedp4j.api.ApiTests",
     runtime_deps = [
@@ -32,15 +32,15 @@
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:dp_functions_params",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_spec",
         "//tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_tests",
+        "@maven//:com_fasterxml_jackson_core_jackson_databind",
+        "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
         "@maven//:com_google_truth_truth",
         "@maven//:junit_junit",
         "@maven//:org_apache_beam_beam_sdks_java_core",
-        "@maven//:org_jetbrains_kotlin_kotlin_test",
-        "@maven//:org_apache_spark_spark_core_2_13",
-        "@maven//:org_apache_spark_spark_sql_2_13",
-        "@maven//:org_apache_spark_spark_mllib_2_13",
         "@maven//:org_apache_spark_spark_catalyst_2_13",
-        "@maven//:com_fasterxml_jackson_core_jackson_databind",
-        "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+        "@maven//:org_apache_spark_spark_core_2_13",
+        "@maven//:org_apache_spark_spark_mllib_2_13",
+        "@maven//:org_apache_spark_spark_sql_2_13",
+        "@maven//:org_jetbrains_kotlin_kotlin_test",
     ],
 )
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt
index e3e046b..5902b8e 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryBuilderTest.kt
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package com.google.privacy.differentialprivacy.pipelinedp4j.api
 
 import com.google.common.truth.Truth.assertThat
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt
index 2d5b5c0..bffdb88 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkQueryTest.kt
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package com.google.privacy.differentialprivacy.pipelinedp4j.api
 
 import com.google.common.truth.Truth.assertThat
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel
index b3734a9..28c0e27 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BUILD.bazel
@@ -37,6 +37,8 @@
         "@maven//:com_google_truth_truth",
         "@maven//:junit_junit",
         "@maven//:org_apache_beam_beam_sdks_java_core",
+        "@maven//:org_apache_beam_beam_sdks_java_extensions_avro",
+        "@maven//:org_apache_beam_beam_sdks_java_extensions_protobuf",
         "@maven//:org_hamcrest_hamcrest",
     ],
 )
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt
index c76201d..7a1b9b2 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/beam/BeamEncodersTest.kt
@@ -16,6 +16,8 @@
 
 package com.google.privacy.differentialprivacy.pipelinedp4j.beam
 
+import org.apache.beam.sdk.extensions.protobuf.ProtoCoder
+import com.google.common.truth.Truth.assertThat
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.ContributionWithPrivacyId
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.contributionWithPrivacyId
 import com.google.privacy.differentialprivacy.pipelinedp4j.core.encoderOfContributionWithPrivacyId
@@ -25,6 +27,10 @@
 import com.google.privacy.differentialprivacy.pipelinedp4j.proto.quantilesAccumulator
 import com.google.privacy.differentialprivacy.pipelinedp4j.proto.sumAccumulator
 import com.google.protobuf.ByteString
+import org.apache.beam.sdk.coders.DoubleCoder
+import org.apache.beam.sdk.coders.StringUtf8Coder
+import org.apache.beam.sdk.coders.VarIntCoder
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder
 import org.apache.beam.sdk.testing.PAssert
 import org.apache.beam.sdk.testing.TestPipeline
 import org.apache.beam.sdk.transforms.Create
@@ -75,21 +81,9 @@
 
   @Test
   fun records_isPossibleToCreateBeamPCollectionOfThatType() {
-    val input =
-      listOf(
-        contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
-        contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
-        contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
-        contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
-      )
+    val input = listOf(TestRecord("privacyId1", 1.0, -1), TestRecord("privacyId2", 2.0, 2))
     val inputCoder =
-      (encoderOfContributionWithPrivacyId(
-          beamEncoderFactory.strings(),
-          beamEncoderFactory.strings(),
-          beamEncoderFactory,
-        )
-          as BeamEncoder<ContributionWithPrivacyId<String, String>>)
-        .coder
+      (beamEncoderFactory.records(TestRecord::class) as BeamEncoder<TestRecord>).coder
 
     val pCollection = testPipeline.apply(Create.of(input).withCoder(inputCoder))
 
@@ -115,7 +109,9 @@
         },
         compoundAccumulator {},
       )
-    val inputCoder = beamEncoderFactory.protos(CompoundAccumulator::class).coder
+    val inputCoder =
+      (beamEncoderFactory.protos(CompoundAccumulator::class) as BeamEncoder<CompoundAccumulator>)
+        .coder
 
     val pCollection = testPipeline.apply(Create.of(input).withCoder(inputCoder))
 
@@ -137,6 +133,79 @@
     testPipeline.run().waitUntilFinish()
   }
 
+  @Test
+  fun contributionWithPrivacyIdOf_isPossibleToCreateBeamPCollectionOfThatType() {
+    val input =
+      listOf(
+        contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
+        contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
+        contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
+        contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
+      )
+    val inputCoder =
+      (encoderOfContributionWithPrivacyId(
+          beamEncoderFactory.strings(),
+          beamEncoderFactory.strings(),
+          beamEncoderFactory,
+        )
+          as BeamEncoder<ContributionWithPrivacyId<String, String>>)
+        .coder
+
+    val pCollection = testPipeline.apply(Create.of(input).withCoder(inputCoder))
+
+    PAssert.that(pCollection).containsInAnyOrder(input)
+
+    testPipeline.run().waitUntilFinish()
+  }
+
+  @Test
+  fun recordsOfUnknownClass_string_createsEncoderWithStringCoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder = beamEncoderFactory.recordsOfUnknownClass(String::class) as BeamEncoder<String>
+
+    assertThat(encoder.coder).isInstanceOf(StringUtf8Coder::class.java)
+  }
+
+  @Test
+  fun recordsOfUnknownClass_double_createsEncoderWithDoubleCoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder = beamEncoderFactory.recordsOfUnknownClass(Double::class) as BeamEncoder<Double>
+
+    assertThat(encoder.coder).isInstanceOf(DoubleCoder::class.java)
+  }
+
+  @Test
+  fun recordsOfUnknownClass_int_createsEncoderWithIntCoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder = beamEncoderFactory.recordsOfUnknownClass(Int::class) as BeamEncoder<Int>
+
+    assertThat(encoder.coder).isInstanceOf(VarIntCoder::class.java)
+  }
+
+  @Test
+  fun recordsOfUnknownClass_kotlinClass_createsEncoderWithAvroCoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder =
+      beamEncoderFactory.recordsOfUnknownClass(TestRecord::class) as BeamEncoder<TestRecord>
+
+    assertThat(encoder.coder).isInstanceOf(AvroCoder::class.java)
+  }
+
+  @Test
+  fun recordsOfUnknownClass_proto_createsEncoderWithProtoCoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder =
+      beamEncoderFactory.recordsOfUnknownClass(CompoundAccumulator::class)
+        as BeamEncoder<CompoundAccumulator>
+
+    assertThat(encoder.coder).isInstanceOf(ProtoCoder::class.java)
+  }
+
+  private data class TestRecord(val string: String, val double: Double, val int: Int) {
+    // Required for Beam serialization.
+    private constructor() : this("", 0.0, 0)
+  }
+
   companion object {
     private val beamEncoderFactory = BeamEncoderFactory()
   }
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
index 804cd08..c1cef1f 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/BUILD.bazel
@@ -31,15 +31,15 @@
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/proto:accumulators_kt_proto",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_collections",
         "//main/com/google/privacy/differentialprivacy/pipelinedp4j/spark:spark_encoders",
+        "@maven//:com_fasterxml_jackson_core_jackson_databind",
+        "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
         "@maven//:com_google_protobuf_protobuf_java",
         "@maven//:com_google_testparameterinjector_test_parameter_injector",
         "@maven//:com_google_truth_truth",
         "@maven//:junit_junit",
-        "@maven//:org_apache_spark_spark_core_2_13",
-        "@maven//:org_apache_spark_spark_sql_2_13",
-        "@maven//:org_apache_spark_spark_mllib_2_13",
         "@maven//:org_apache_spark_spark_catalyst_2_13",
-        "@maven//:com_fasterxml_jackson_core_jackson_databind",
-        "@maven//:com_fasterxml_jackson_module_jackson_module_paranamer",
+        "@maven//:org_apache_spark_spark_core_2_13",
+        "@maven//:org_apache_spark_spark_mllib_2_13",
+        "@maven//:org_apache_spark_spark_sql_2_13",
     ],
 )
diff --git a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt
index 420008e..1267e62 100644
--- a/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt
+++ b/pipelinedp4j/tests/com/google/privacy/differentialprivacy/pipelinedp4j/spark/SparkEncodersTest.kt
@@ -26,6 +26,7 @@
 import com.google.privacy.differentialprivacy.pipelinedp4j.proto.quantilesAccumulator
 import com.google.privacy.differentialprivacy.pipelinedp4j.proto.sumAccumulator
 import com.google.protobuf.ByteString
+import org.apache.spark.sql.Encoders
 import org.junit.ClassRule
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -38,7 +39,9 @@
   fun strings_isPossibleToCreateSparkCollectionOfThatType() {
     val input = listOf("a", "b", "c")
     val inputCoder = sparkEncoderFactory.strings().encoder
+
     val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
@@ -46,7 +49,9 @@
   fun doubles_isPossibleToCreateSparkCollectionOfThatType() {
     val input = listOf(-1.2, 0.0, 2.1)
     val inputCoder = sparkEncoderFactory.doubles().encoder
+
     val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
@@ -54,28 +59,20 @@
   fun ints_isPossibleToCreateSparkCollectionOfThatType() {
     val input = listOf(-1, 0, 1)
     val inputCoder = sparkEncoderFactory.ints().encoder
+
     val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
   @Test
   fun records_isPossibleToCreateSparkCollectionOfThatType() {
-    val input =
-      listOf(
-        contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
-        contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
-        contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
-        contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
-      )
+    val input = listOf(TestRecord("privacyId1", 1.0, -1), TestRecord("privacyId2", 2.0, 2))
     val inputCoder =
-      (encoderOfContributionWithPrivacyId(
-          sparkEncoderFactory.strings(),
-          sparkEncoderFactory.strings(),
-          sparkEncoderFactory,
-        )
-          as SparkEncoder<ContributionWithPrivacyId<String, String>>)
-        .encoder
+      (sparkEncoderFactory.records(TestRecord::class) as SparkEncoder<TestRecord>).encoder
+
     val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
@@ -96,8 +93,12 @@
         },
         compoundAccumulator {},
       )
-    val inputCoder = sparkEncoderFactory.protos(CompoundAccumulator::class).encoder
+    val inputCoder =
+      (sparkEncoderFactory.protos(CompoundAccumulator::class) as SparkEncoder<CompoundAccumulator>)
+        .encoder
+
     val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
@@ -110,6 +111,7 @@
         .encoder
 
     val dataset = sparkSession.spark.createDataset(input, inputEncoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
@@ -129,10 +131,80 @@
           sparkEncoderFactory.strings(),
         )
         .encoder
+
     val dataset = sparkSession.spark.createDataset(input, inputEncoder)
+
     assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
   }
 
+  @Test
+  fun contributionWithPrivacyIdOf_isPossibleToCreateSparkCollectionOfThatType() {
+    val input =
+      listOf(
+        contributionWithPrivacyId("privacyId1", "partitionKey1", -1.0),
+        contributionWithPrivacyId("privacyId2", "partitionKey1", 0.0),
+        contributionWithPrivacyId("privacyId1", "partitionKey2", 1.0),
+        contributionWithPrivacyId("privacyId3", "partitionKey3", 1.2345),
+      )
+    val inputCoder =
+      (encoderOfContributionWithPrivacyId(
+          sparkEncoderFactory.strings(),
+          sparkEncoderFactory.strings(),
+          sparkEncoderFactory,
+        )
+          as SparkEncoder<ContributionWithPrivacyId<String, String>>)
+        .encoder
+
+    val dataset = sparkSession.spark.createDataset(input, inputCoder)
+
+    assertThat(dataset.collectAsList()).containsExactlyElementsIn(input)
+  }
+
+  @Test
+  fun recordsOfUnknownClass_string_createsEncoderWithSparkStringEncoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder = sparkEncoderFactory.recordsOfUnknownClass(String::class) as SparkEncoder<String>
+
+    assertThat(encoder.encoder).isEqualTo(Encoders.STRING())
+  }
+
+  @Test
+  fun recordsOfUnknownClass_double_createsEncoderWithSparkDoubleEncoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder = sparkEncoderFactory.recordsOfUnknownClass(Double::class) as SparkEncoder<Double>
+
+    assertThat(encoder.encoder).isEqualTo(Encoders.DOUBLE())
+  }
+
+  @Test
+  fun recordsOfUnknownClass_int_createsEncoderWithSparkIntEncoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder = sparkEncoderFactory.recordsOfUnknownClass(Int::class) as SparkEncoder<Int>
+
+    assertThat(encoder.encoder).isEqualTo(Encoders.INT())
+  }
+
+  @Test
+  fun recordsOfUnknownClass_kotlinClass_createsEncoderWithSparkBeanEncoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder =
+      sparkEncoderFactory.recordsOfUnknownClass(TestRecord::class) as SparkEncoder<TestRecord>
+
+    assertThat(encoder.encoder).isEqualTo(Encoders.bean(TestRecord::class.java))
+  }
+
+  @Test
+  fun recordsOfUnknownClass_proto_createsEncoderWithSparkKryoEncoder() {
+    @Suppress("UNCHECKED_CAST")
+    val encoder =
+      sparkEncoderFactory.recordsOfUnknownClass(CompoundAccumulator::class)
+        as SparkEncoder<CompoundAccumulator>
+
+    assertThat(encoder.encoder).isEqualTo(Encoders.kryo(CompoundAccumulator::class.java))
+  }
+
+  data class TestRecord(var string: String = "", var double: Double = 0.0, var int: Int = 0)
+
   companion object {
     @JvmField @ClassRule val sparkSession = SparkSessionRule()
     private val sparkEncoderFactory = SparkEncoderFactory()
diff --git a/privacy-on-beam/pbeam/public_partitions.go b/privacy-on-beam/pbeam/public_partitions.go
index f95f6ae..145d0ac 100644
--- a/privacy-on-beam/pbeam/public_partitions.go
+++ b/privacy-on-beam/pbeam/public_partitions.go
@@ -22,12 +22,19 @@
 	"bytes"
 	"encoding/base64"
 	"fmt"
+	"math/rand"
 	"reflect"
 
+	"flag"
 	"github.com/google/differential-privacy/privacy-on-beam/v3/internal/kv"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter"
+)
+
+var (
+	enableShardedPublicPartitions = flag.Bool("enable_sharded_public_partitions", false, "Enable sharded public partitions. This is a temporary flag to allow us to test the new sharded implementation of public partition filtering.")
 )
 
 func init() {
@@ -51,6 +58,11 @@
 	register.Function4x0[beam.W, func(*beam.V) bool, func(*beam.V) bool, func(beam.W, beam.V)](mergeResultWithEmptyPublicPartitionsFn)
 	register.Iter1[beam.V]()
 	register.Emitter2[beam.W, beam.V]()
+
+	register.Function1x2[kv.Pair, ShardedKey, []byte](addRandomShardIDFn)
+	register.Function4x0[[]byte, func(*[]byte) bool, func(*int) bool, func(ShardedKey, int)](extractEmittableKeysWithShardIDFn)
+	register.DoFn4x1[ShardedKey, func(*int) bool, func(*[]byte) bool, func(beam.T, beam.W), error](&filterKeysWithShardIDFn{})
+	register.Function1x2[kv.Pair, []byte, []byte](unwrapPairFn)
 }
 
 // newAddZeroValuesToPublicPartitionsFn turns a PCollection<V> into PCollection<V,0>.
@@ -138,6 +150,116 @@
 	}
 }
 
+// ShardedKey is a key encoded as bytes with an int shardID.
+type ShardedKey struct {
+	K       []byte
+	ShardID int
+}
+
+func addRandomShardIDFn(encoded kv.Pair) (ShardedKey, []byte) {
+	return ShardedKey{K: encoded.K, ShardID: rand.Intn(2048)}, encoded.V
+}
+
+func extractEmittableKeysWithShardIDFn(k []byte, isAllowedKeyIter func(*[]byte) bool, shardIDIter func(*int) bool, emit func(ShardedKey, int)) {
+	var isAllowedKey []byte
+	if !isAllowedKeyIter(&isAllowedKey) {
+		// k is not an allowlisted key, filter it out.
+		return
+	}
+
+	var subkey int
+	for shardIDIter(&subkey) {
+		emit(ShardedKey{K: k, ShardID: subkey}, 0)
+	}
+}
+
+type filterKeysWithShardIDFn struct {
+	KType     beam.EncodedType
+	VType     beam.EncodedType
+	PairCodec *kv.Codec
+}
+
+func newFilterKeysWithShardIDFn(kType, vType reflect.Type) *filterKeysWithShardIDFn {
+	return &filterKeysWithShardIDFn{
+		KType: beam.EncodedType{T: kType},
+		VType: beam.EncodedType{T: vType},
+	}
+}
+
+func (fn *filterKeysWithShardIDFn) Setup() error {
+	fn.PairCodec = kv.NewCodec(fn.KType.T, fn.VType.T)
+	return fn.PairCodec.Setup()
+}
+
+func (fn *filterKeysWithShardIDFn) ProcessElement(k ShardedKey, isEmittableShardIDIter func(*int) bool, pcolValueIter func(*[]byte) bool, emit func(beam.T, beam.W)) error {
+	var isEmittableShardID int
+	if !isEmittableShardIDIter(&isEmittableShardID) {
+		// k isn't a key from the public partitions collection.
+		return nil
+	}
+
+	var pcolValue []byte
+	for pcolValueIter(&pcolValue) {
+		k, v, err := fn.PairCodec.Decode(kv.Pair{K: k.K, V: pcolValue})
+		if err != nil {
+			return err
+		}
+		emit(k, v)
+	}
+
+	return nil
+}
+
+func unwrapPairFn(encoded kv.Pair) ([]byte, []byte) {
+	return encoded.K, encoded.V
+}
+
+func unwrapShardedKeyFn(shardedKey ShardedKey) ([]byte, int) {
+	return shardedKey.K, shardedKey.ShardID
+}
+
+// Keeps only the KV-s from col that have a key in 'keys'; all other KV-s are filtered out.
+//
+// A single key in col may have a huge number of values. This function handles that
+// case by sharding col randomly into 2048 collections, and then joining the keys
+// within each of those shards. This reduces stragglers when there is a hot key in col,
+// since its processing is parallelized 2048 ways.
+//
+// Each value in col is randomly selected to be in one of the 2048 shards. Then the
+// 'keys' collection is joined with the sharded col collection to find the shardIds
+// that need to be present per key. Then the sharded col collection is joined with
+// the sharded 'keys' collection to compute the final filtered result.
+func filterKeysImbalanced(s beam.Scope, col beam.PCollection, keys beam.PCollection) beam.PCollection {
+	s = s.Scope("filterKeysImbalanced")
+
+	kT, vT := beam.ValidateKVType(col)
+
+	// Add a random shardId (one of 2048) to each element in col.
+	// PCollection<KV<ShardedKey[key, randInt[0;2048)], []byte>>
+	pcolAsBytesWithShardID := beam.ParDo(s, addRandomShardIDFn, beam.ParDo(s, kv.NewEncodeFn(kT, vT), col))
+	// Drop values and remove duplicates.
+	// PCollection<ShardedKey[key, randInt[0;2048)]>
+	uniqueKeysWithShardID := filter.Distinct(s, beam.DropValue(s, pcolAsBytesWithShardID))
+
+	// Prepare the keys for a CoGroupBy with uniqueKeysWithShardID.
+	// PCollection<KV<key, 0>>
+	keysWithZero := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, keys)
+	// PCollection<kv.Pair<key, 0>>
+	keysEncodedWithZero := beam.ParDo(
+		s,
+		kv.NewEncodeFn(keysWithZero.Type().Components()[0],
+			keysWithZero.Type().Components()[1]), keysWithZero)
+
+	// Find the shardIds per key in the keys collection.
+	// PCollection<KV<key, *>>
+	groupedByKey := beam.CoGroupByKey(s, beam.ParDo(s, unwrapPairFn, keysEncodedWithZero), beam.ParDo(s, unwrapShardedKeyFn, uniqueKeysWithShardID))
+	emittableKeysWithShardID := beam.ParDo(s, extractEmittableKeysWithShardIDFn, groupedByKey)
+
+	// Finally perform the sharded filter.
+	groupedByKeyAndShardID := beam.CoGroupByKey(s, emittableKeysWithShardID, pcolAsBytesWithShardID)
+	return beam.ParDo(s, newFilterKeysWithShardIDFn(kT.Type(), vT.Type()), groupedByKeyAndShardID, beam.TypeDefinition{Var: beam.TType, T: kT.Type()}, beam.TypeDefinition{Var: beam.WType, T: vT.Type()})
+}
+
 // dropNonPublicPartitionsVFn drops partitions not specified in
 // PublicPartitions from pcol. It can be used for aggregations on V values,
 // e.g. Count and DistinctPrivacyID.
@@ -152,9 +274,13 @@
 // Returns a PCollection<PrivacyKey, Value> only for values present in
 // publicPartitions.
 func dropNonPublicPartitionsVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection) beam.PCollection {
-	publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
-	groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, beam.SwapKV(s, pcol.col))
-	return beam.ParDo(s, mergePublicValues, groupedByValue)
+	if *enableShardedPublicPartitions {
+		return beam.SwapKV(s, filterKeysImbalanced(s, beam.SwapKV(s, pcol.col), publicPartitions))
+	} else {
+		publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
+		groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, beam.SwapKV(s, pcol.col))
+		return beam.ParDo(s, mergePublicValues, groupedByValue)
+	}
 }
 
 // dropNonPublicPartitionsKVFn drops partitions not specified in
@@ -175,12 +301,20 @@
 // Returns a PCollection<PrivacyKey, <PartitionKey, Value>> only for values present in
 // publicPartitions.
 func dropNonPublicPartitionsKVFn(s beam.Scope, publicPartitions beam.PCollection, pcol PrivatePCollection, idType typex.FullType) beam.PCollection {
-	publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
-	encodedIDV := beam.ParDo(s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T})
-	groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, encodedIDV)
-	merged := beam.SwapKV(s, beam.ParDo(s, mergePublicValues, groupedByValue))
-	decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T))
-	return beam.ParDo(s, decodeFn, merged, beam.TypeDefinition{Var: beam.UType, T: idType.Type()})
+	if *enableShardedPublicPartitions {
+		encodedIDV := beam.ParDo(
+			s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T})
+		filteredEncodedIDV := filterKeysImbalanced(s, encodedIDV, publicPartitions)
+		decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T))
+		return beam.ParDo(s, decodeFn, filteredEncodedIDV, beam.TypeDefinition{Var: beam.UType, T: idType.Type()})
+	} else {
+		publicPartitionsWithZeros := beam.ParDo(s, addZeroValuesToPublicPartitionsInt64, publicPartitions)
+		encodedIDV := beam.ParDo(s, newEncodeIDVFn(idType, pcol.codec), pcol.col, beam.TypeDefinition{Var: beam.WType, T: pcol.codec.KType.T})
+		groupedByValue := beam.CoGroupByKey(s, publicPartitionsWithZeros, encodedIDV)
+		merged := beam.SwapKV(s, beam.ParDo(s, mergePublicValues, groupedByValue))
+		decodeFn := newDecodeIDVFn(pcol.codec.KType, kv.NewCodec(idType.Type(), pcol.codec.VType.T))
+		return beam.ParDo(s, decodeFn, merged, beam.TypeDefinition{Var: beam.UType, T: idType.Type()})
+	}
 }
 
 // encodeIDVFn takes a PCollection<ID,kv.Pair{K,V}> as input, and returns a
diff --git a/privacy-on-beam/pbeam/public_partitions_test.go b/privacy-on-beam/pbeam/public_partitions_test.go
index 54e83e1..1ed8a00 100644
--- a/privacy-on-beam/pbeam/public_partitions_test.go
+++ b/privacy-on-beam/pbeam/public_partitions_test.go
@@ -22,6 +22,7 @@
 	"reflect"
 	"testing"
 
+	"flag"
 	"github.com/google/differential-privacy/privacy-on-beam/v3/pbeam/testutils"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
@@ -104,6 +105,40 @@
 	}
 }
 
+// TODO: Remove once the enable_sharded_public_partitions flag is gone.
+func TestDropNonPublicPartitionsVFnShardedImpl(t *testing.T) {
+	flag.Set("enable_sharded_public_partitions", "true")
+
+	pairs := testutils.ConcatenatePairs(
+		testutils.MakePairsWithFixedV(7, 0),
+		testutils.MakePairsWithFixedVStartingFromKey(7, 10, 1),
+		testutils.MakePairsWithFixedVStartingFromKey(17, 83, 2),
+		testutils.MakePairsWithFixedVStartingFromKey(100, 10, 3),
+	)
+
+	// Keep partitions 0, 2;
+	// drop partitions 1, 3.
+	result := testutils.ConcatenatePairs(
+		testutils.MakePairsWithFixedV(7, 0),
+		testutils.MakePairsWithFixedVStartingFromKey(17, 83, 2),
+	)
+
+	p, s, col, want := ptest.CreateList2(pairs, result)
+	want = beam.ParDo(s, testutils.PairToKV, want)
+	col = beam.ParDo(s, testutils.PairToKV, col)
+	partitions := []int{0, 2}
+
+	partitionsCol := beam.CreateList(s, partitions)
+	epsilon := 50.0
+	pcol := MakePrivate(s, col, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+	got := dropNonPublicPartitionsVFn(s, partitionsCol, pcol)
+	testutils.EqualsKVInt(t, s, got, want)
+	if err := ptest.Run(p); err != nil {
+		t.Errorf("DropNonPublicPartitionsVFn did not drop non public partitions as expected: %v", err)
+	}
+	flag.Set("enable_sharded_public_partitions", "false")
+}
+
 // TestDropNonPublicPartitionsKVFn checks that int elements with non-public partitions
 // are dropped (tests function used for sum and mean).
 func TestDropNonPublicPartitionsKVFn(t *testing.T) {
@@ -147,6 +182,51 @@
 	}
 }
 
+// TODO: Remove once the enable_sharded_public_partitions flag is gone.
+func TestDropNonPublicPartitionsKVFnShardedImpl(t *testing.T) {
+	flag.Set("enable_sharded_public_partitions", "true")
+
+	triples := testutils.ConcatenateTriplesWithIntValue(
+		testutils.MakeTripleWithIntValueStartingFromKey(0, 7, 0, 0),
+		testutils.MakeTripleWithIntValueStartingFromKey(7, 3, 1, 0),
+		testutils.MakeTripleWithIntValueStartingFromKey(10, 90, 2, 0),
+		testutils.MakeTripleWithIntValueStartingFromKey(100, 100, 11, 0),
+		testutils.MakeTripleWithIntValueStartingFromKey(200, 5, 12, 0))
+	// Keep partitions 0, 2.
+	// Drop partitions 1, 11, 12.
+	result := testutils.ConcatenateTriplesWithIntValue(
+		testutils.MakeTripleWithIntValueStartingFromKey(0, 7, 0, 0),
+		testutils.MakeTripleWithIntValueStartingFromKey(10, 90, 2, 0))
+
+	p, s, col, col2 := ptest.CreateList2(triples, result)
+	// Doesn't matter that the values 3, 4, 5, 6, 9, 10
+	// are in the partitions PCollection because we are
+	// just dropping the values that are in our original PCollection
+	// that are not in public partitions.
+	partitionsCol := beam.CreateList(s, []int{0, 2, 3, 4, 5, 6, 9, 10})
+	col = beam.ParDo(s, testutils.ExtractIDFromTripleWithIntValue, col)
+	col2 = beam.ParDo(s, testutils.ExtractIDFromTripleWithIntValue, col2)
+	epsilon := 50.0
+
+	pcol := MakePrivate(s, col, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+	pcol = ParDo(s, testutils.TripleWithIntValueToKV, pcol)
+	idT, _ := beam.ValidateKVType(pcol.col)
+
+	got := dropNonPublicPartitionsKVFn(s, partitionsCol, pcol, idT)
+	got = beam.SwapKV(s, got)
+
+	pcol2 := MakePrivate(s, col2, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+	pcol2 = ParDo(s, testutils.TripleWithIntValueToKV, pcol2)
+	want := pcol2.col
+	want = beam.SwapKV(s, want)
+
+	testutils.EqualsKVInt(t, s, got, want)
+	if err := ptest.Run(p); err != nil {
+		t.Errorf("TestDropNonPublicPartitionsKVFnShardedImpl did not drop non public partitions as expected: %v", err)
+	}
+	flag.Set("enable_sharded_public_partitions", "false")
+}
+
 // Check that float elements with non-public partitions
 // are dropped (tests function used for sum and mean).
 func TestDropNonPublicPartitionsFloat(t *testing.T) {
@@ -190,3 +270,49 @@
 		t.Errorf("TestDropNonPublicPartitionsFloat did not drop non public partitions as expected: %v", err)
 	}
 }
+
+// TODO: Remove once the enable_sharded_public_partitions flag is gone.
+func TestDropNonPublicPartitionsFloatShardedImpl(t *testing.T) {
+	flag.Set("enable_sharded_public_partitions", "true")
+
+	// In this test, we check  that non-public partitions
+	// are dropped. This function is used for sum and mean.
+	// Used example values from the mean test.
+	triples := testutils.ConcatenateTriplesWithFloatValue(
+		testutils.MakeTripleWithFloatValue(7, 0, 2.0),
+		testutils.MakeTripleWithFloatValueStartingFromKey(7, 100, 1, 1.3),
+		testutils.MakeTripleWithFloatValueStartingFromKey(107, 150, 1, 2.5),
+	)
+	// Keep partition 0.
+	// drop partition 1.
+	result := testutils.ConcatenateTriplesWithFloatValue(
+		testutils.MakeTripleWithFloatValue(7, 0, 2.0))
+
+	p, s, col, col2 := ptest.CreateList2(triples, result)
+
+	// Doesn't matter that the values 2, 3, 4, 5, 6, 7 are in the partitions PCollection.
+	// We are just dropping the values that are in our original PCollection that are not in
+	// public partitions.
+	partitionsCol := beam.CreateList(s, []int{0, 2, 3, 4, 5, 6, 7})
+	col = beam.ParDo(s, testutils.ExtractIDFromTripleWithFloatValue, col)
+	col2 = beam.ParDo(s, testutils.ExtractIDFromTripleWithFloatValue, col2)
+	epsilon := 50.0
+
+	pcol := MakePrivate(s, col, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+	pcol = ParDo(s, testutils.TripleWithFloatValueToKV, pcol)
+	idT, _ := beam.ValidateKVType(pcol.col)
+
+	got := dropNonPublicPartitionsKVFn(s, partitionsCol, pcol, idT)
+	got = beam.SwapKV(s, got)
+
+	pcol2 := MakePrivate(s, col2, privacySpec(t, PrivacySpecParams{AggregationEpsilon: epsilon}))
+	pcol2 = ParDo(s, testutils.TripleWithFloatValueToKV, pcol2)
+	want := pcol2.col
+	want = beam.SwapKV(s, want)
+
+	testutils.EqualsKVInt(t, s, got, want)
+	if err := ptest.Run(p); err != nil {
+		t.Errorf("TestDropNonPublicPartitionsFloatShardedImpl did not drop non public partitions as expected: %v", err)
+	}
+	flag.Set("enable_sharded_public_partitions", "false")
+}
diff --git a/privacy-on-beam/pbeam/testutils/testutils.go b/privacy-on-beam/pbeam/testutils/testutils.go
index 04851e6..8151f2d 100644
--- a/privacy-on-beam/pbeam/testutils/testutils.go
+++ b/privacy-on-beam/pbeam/testutils/testutils.go
@@ -24,6 +24,7 @@
 	"math"
 	"math/big"
 	"reflect"
+	"sort"
 	"testing"
 
 	"github.com/google/differential-privacy/go/v3/dpagg"
@@ -510,6 +511,7 @@
 	for vIter(&v) {
 		vSlice = append(vSlice, float64(v))
 	}
+	sort.Float64s(vSlice)
 	return vSlice
 }