blob: ff6559f781dcf5036186cab9c058467b9ba901e2 [file] [log] [blame]
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.privacy.differentialprivacy.pipelinedp4j.beam
import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
import com.google.protobuf.Message
import java.io.InputStream
import java.io.OutputStream
import kotlin.reflect.KClass
import org.apache.beam.sdk.coders.Coder
import org.apache.beam.sdk.coders.CustomCoder
import org.apache.beam.sdk.coders.DoubleCoder
import org.apache.beam.sdk.coders.StringUtf8Coder
import org.apache.beam.sdk.coders.VarIntCoder
import org.apache.beam.sdk.extensions.avro.coders.AvroCoder
import org.apache.beam.sdk.extensions.protobuf.ProtoCoder
/**
 * [Encoder] implementation used by the Beam backend.
 *
 * It is a thin holder around the Beam [coder] for values of type [T] and adds no behavior of its
 * own.
 */
class BeamEncoder<T>(
  /** Beam coder used to serialize and deserialize values of type [T]. */
  val coder: Coder<T>,
) : Encoder<T>
/**
 * [EncoderFactory] for the Beam backend: every encoder it returns is a [BeamEncoder] wrapping a
 * Beam [Coder].
 *
 * Strings, doubles and ints use Beam's built-in coders ([StringUtf8Coder], [DoubleCoder],
 * [VarIntCoder]); records use [AvroCoder]; protos use [ProtoCoder]; pairs use a coder composed
 * from the coders of the two component encoders.
 */
class BeamEncoderFactory : EncoderFactory {
  override fun strings(): BeamEncoder<String> = BeamEncoder(StringUtf8Coder.of())

  override fun doubles(): BeamEncoder<Double> = BeamEncoder(DoubleCoder.of())

  override fun ints(): BeamEncoder<Int> = BeamEncoder(VarIntCoder.of())

  override fun <T : Any> records(recordClass: KClass<T>): BeamEncoder<T> =
    BeamEncoder(AvroCoder.of(recordClass.java))

  override fun <T : Message> protos(protoClass: KClass<T>): BeamEncoder<T> =
    BeamEncoder(ProtoCoder.of(protoClass.java))

  /**
   * Returns an encoder for [Pair]s whose components are encoded with [first] and [second].
   *
   * @throws IllegalArgumentException if [first] or [second] is not a [BeamEncoder].
   */
  override fun <T1 : Any, T2 : Any> tuple2sOf(
    first: Encoder<T1>,
    second: Encoder<T2>,
  ): BeamEncoder<Pair<T1, T2>> =
    BeamEncoder(KotlinPairCoder(first.beamCoder(), second.beamCoder()))

  /**
   * Returns the Beam [Coder] carried by this encoder.
   *
   * @throws IllegalArgumentException if this encoder is not a [BeamEncoder].
   */
  // Type arguments are erased at runtime; only the BeamEncoder class itself is checked.
  @Suppress("UNCHECKED_CAST")
  private fun <T> Encoder<T>.beamCoder(): Coder<T> =
    (this as? BeamEncoder<T>)?.coder
      ?: throw IllegalArgumentException(
        "Expected a BeamEncoder, but got an encoder of type ${this.javaClass.name}."
      )
}
/**
 * Beam [Coder] for Kotlin [Pair]s built from one coder per component.
 *
 * A pair is encoded as the encoding of [Pair.first] (via [firstCoder]) immediately followed by the
 * encoding of [Pair.second] (via [secondCoder]), and decoded in the same order. Decoding therefore
 * relies on the first component's encoding marking its own end — presumably true for coders used
 * in this way; confirm for custom component coders.
 *
 * Two instances are equal when their component coders are equal, mirroring Beam's `KvCoder`.
 */
private class KotlinPairCoder<FirstT, SecondT>(
  private val firstCoder: Coder<FirstT>,
  private val secondCoder: Coder<SecondT>,
) : CustomCoder<Pair<FirstT, SecondT>>() {

  override fun encode(value: Pair<FirstT, SecondT>, out: OutputStream) {
    firstCoder.encode(value.first, out)
    secondCoder.encode(value.second, out)
  }

  override fun decode(inStream: InputStream): Pair<FirstT, SecondT> {
    val first = firstCoder.decode(inStream)
    val second = secondCoder.decode(inStream)
    return Pair(first, second)
  }

  /** The component coders, in the order of [Pair]'s type parameters. */
  override fun getCoderArguments(): List<Coder<*>> = listOf(firstCoder, secondCoder)

  /** Deterministic only if both component coders are deterministic. */
  override fun verifyDeterministic() {
    Coder.verifyDeterministic(
      this,
      "KotlinPairCoder is not deterministic",
      listOf(firstCoder, secondCoder),
    )
  }

  /**
   * [Pair] compares component-wise, so this coder is consistent with `equals` exactly when both
   * component coders are.
   */
  override fun consistentWithEquals(): Boolean =
    firstCoder.consistentWithEquals() && secondCoder.consistentWithEquals()

  /**
   * Returns [value] itself when this coder is consistent with `equals`; otherwise a [Pair] of the
   * components' structural values. This avoids Beam's default of re-encoding the whole pair to
   * bytes.
   */
  override fun structuralValue(value: Pair<FirstT, SecondT>): Any =
    if (consistentWithEquals()) {
      value
    } else {
      Pair(firstCoder.structuralValue(value.first), secondCoder.structuralValue(value.second))
    }

  override fun equals(other: Any?): Boolean =
    other is KotlinPairCoder<*, *> &&
      firstCoder == other.firstCoder &&
      secondCoder == other.secondCoder

  override fun hashCode(): Int = 31 * firstCoder.hashCode() + secondCoder.hashCode()

  override fun toString(): String = "KotlinPairCoder($firstCoder, $secondCoder)"
}