blob: ad4f95c6e3e8eec3abf70f5872e51c24f1547186 [file]
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.privacy.differentialprivacy.pipelinedp4j.beam
import com.google.common.truth.Truth.assertThat
import org.apache.beam.sdk.coders.VarIntCoder
import org.apache.beam.sdk.testing.PAssert
import org.apache.beam.sdk.testing.TestPipeline
import org.apache.beam.sdk.transforms.Create
import org.apache.beam.sdk.values.KV
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
/**
 * Tests for [BeamCollection].
 *
 * Each test builds a small `Int` input on [testPipeline] (encoded with [VarIntCoder]) and applies a
 * single [BeamCollection] operation. The pipeline output is checked with [PAssert] before the
 * pipeline is run. The encoder is checked with Truth after the pipeline is run.
 */
@RunWith(JUnit4::class)
class BeamCollectionTest {
  @get:Rule val testPipeline: TestPipeline = TestPipeline.create()

  @Test
  fun elementsEncoder_returnsCorrectEncoder() {
    val encoder = intCollectionOf().elementsEncoder

    // Execute the (empty) pipeline; the encoder assertions below do not depend on its output.
    testPipeline.run().waitUntilFinish()

    assertThat(encoder).isInstanceOf(BeamEncoder::class.java)
    assertThat(encoder.coder).isEqualTo(VarIntCoder.of())
  }

  @Test
  fun distinct_removesDuplicates() {
    val deduplicated: BeamCollection<Int> = intCollectionOf(1, 2, 1).distinct("stageName")

    PAssert.that(deduplicated.data).containsInAnyOrder(1, 2)
    testPipeline.run().waitUntilFinish()
  }

  @Test
  fun map_appliesMapFn() {
    val mapped: BeamCollection<String> =
      intCollectionOf(1).map("Test", beamEncoderFactory.strings()) { it.toString() }

    PAssert.that(mapped.data).containsInAnyOrder("1")
    testPipeline.run().waitUntilFinish()
  }

  @Test
  fun keyBy_keysCollection() {
    val keyed: BeamTable<String, Int> =
      intCollectionOf(1).keyBy("Test", beamEncoderFactory.strings()) { it.toString() }

    PAssert.that(keyed.data).containsInAnyOrder(KV.of("1", 1))
    testPipeline.run().waitUntilFinish()
  }

  @Test
  fun mapToTable_appliesMapFn() {
    val table: BeamTable<String, Int> =
      intCollectionOf(1).mapToTable(
        "Test",
        beamEncoderFactory.strings(),
        beamEncoderFactory.ints(),
      ) {
        it.toString() to it
      }

    PAssert.that(table.data).containsInAnyOrder(KV.of("1", 1))
    testPipeline.run().waitUntilFinish()
  }

  /**
   * Applies a `Create` of [values] (coded with [VarIntCoder]) to [testPipeline] and wraps the
   * resulting PCollection in a [BeamCollection]. Passing no values yields an empty collection.
   */
  private fun intCollectionOf(vararg values: Int): BeamCollection<Int> =
    BeamCollection(testPipeline.apply(Create.of(values.toList()).withCoder(VarIntCoder.of())))

  companion object {
    /** Shared factory for the encoders passed to the collection operations under test. */
    private val beamEncoderFactory = BeamEncoderFactory()
  }
}