blob: bffdb88ea0332316b2fb439c1b01b504514208bb [file] [log] [blame]
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.privacy.differentialprivacy.pipelinedp4j.api
import com.google.common.truth.Truth.assertThat
import com.google.privacy.differentialprivacy.pipelinedp4j.spark.SparkSessionRule
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.junit.ClassRule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@RunWith(JUnit4::class)
class SparkQueryTest {

  /**
   * Rows of ((groupKey, privacyUnitId), value) shared by all tests:
   * pid1 contributes twice to group1 (1.0 and 1.5), pid2 contributes once (2.0).
   */
  @Suppress("UNCHECKED_CAST") // Encoders.kryo works on the raw Pair class; the element type is known statically.
  private fun createTestDataset(): Dataset<Pair<Pair<String, String>, Double>> =
    sparkSession.spark.createDataset(
      listOf(
        Pair(Pair("group1", "pid1"), 1.0),
        Pair(Pair("group1", "pid1"), 1.5),
        Pair(Pair("group1", "pid2"), 2.0),
      ),
      Encoders.kryo(Pair::class.java) as Encoder<Pair<Pair<String, String>, Double>>,
    )

  @Test
  fun run_onePublicGroupTwoDifferentContributions_allPossibleAggregations_calculatesStatisticsCorrectly() {
    val dataset = createTestDataset()
    val publicGroups = sparkSession.spark.createDataset(listOf("group1"), Encoders.STRING())
    val valueExtractor = { it: Pair<Pair<String, String>, Double> -> it.second }

    // Run every supported aggregation in a single query. A huge epsilon keeps the added
    // noise negligible so results can be checked against the raw statistics.
    val result: Dataset<QueryPerGroupResult> =
      QueryBuilder.from(dataset, { it.first.second })
        .groupBy(
          { it.first.first },
          maxGroupsContributed = 1,
          maxContributionsPerGroup = 2,
          publicGroups,
        )
        .countDistinctPrivacyUnits("pid_cnt")
        .count("cnt")
        .sum(valueExtractor, outputColumnName = "sumResult")
        .mean(valueExtractor, minValue = 1.0, maxValue = 2.0, "meanResult")
        .variance(valueExtractor, minValue = 1.0, maxValue = 2.0, "varianceResult")
        .quantiles(
          valueExtractor,
          ranks = listOf(0.5),
          minValue = 1.0,
          maxValue = 2.0,
          "quantilesResult",
        )
        .build()
        .run(TotalBudget(epsilon = 1000.0), NoiseKind.LAPLACE)

    val output = result.collectAsList()
    assertThat(output).hasSize(1)
    val queryPerGroupResult = output.single()
    assertThat(queryPerGroupResult.groupKey).isEqualTo("group1")
    assertThat(queryPerGroupResult.aggregationResults).hasSize(6)
    assertThat(queryPerGroupResult.aggregationResults.keys)
      .containsExactly(
        "pid_cnt",
        "cnt",
        "sumResult",
        "meanResult",
        "varianceResult",
        "quantilesResult_0.5",
      )
    // Tolerances absorb the (tiny) residual noise from epsilon = 1000.
    assertThat(queryPerGroupResult.aggregationResults["pid_cnt"]).isWithin(0.5).of(2.0)
    assertThat(queryPerGroupResult.aggregationResults["cnt"]).isWithin(0.5).of(3.0)
    assertThat(queryPerGroupResult.aggregationResults["sumResult"]).isWithin(0.5).of(4.5)
    assertThat(queryPerGroupResult.aggregationResults["meanResult"]).isWithin(0.5).of(1.5)
    assertThat(queryPerGroupResult.aggregationResults["varianceResult"]).isWithin(0.05).of(0.16)
    assertThat(queryPerGroupResult.aggregationResults["quantilesResult_0.5"])
      .isWithin(0.5)
      .of(1.5)
  }

  @Test
  fun run_sumAndQuantiles_calculatesCorrectly() {
    val dataset = createTestDataset()
    val publicGroups = sparkSession.spark.createDataset(listOf("group1"), Encoders.STRING())
    val valueExtractor = { it: Pair<Pair<String, String>, Double> -> it.second }

    // Exercises the sum variant with explicit per-privacy-unit total bounds
    // (instead of per-value bounds) together with a quantiles aggregation.
    val result: Dataset<QueryPerGroupResult> =
      QueryBuilder.from(dataset, { it.first.second })
        .groupBy(
          { it.first.first },
          maxGroupsContributed = 1,
          maxContributionsPerGroup = 2,
          publicGroups,
        )
        .sum(
          valueExtractor,
          minTotalValuePerPrivacyUnitInGroup = 2.0,
          maxTotalValuePerPrivacyUnitInGroup = 2.5,
          outputColumnName = "sumResult",
        )
        .quantiles(
          valueExtractor,
          ranks = listOf(0.5),
          minValue = 1.0,
          maxValue = 2.0,
          "quantilesResult",
        )
        .build()
        .run(TotalBudget(epsilon = 1000.0), NoiseKind.LAPLACE)

    val output = result.collectAsList()
    assertThat(output).hasSize(1)
    val queryPerGroupResult = output.single()
    assertThat(queryPerGroupResult.groupKey).isEqualTo("group1")
    assertThat(queryPerGroupResult.aggregationResults).hasSize(2)
    assertThat(queryPerGroupResult.aggregationResults.keys)
      .containsExactly("sumResult", "quantilesResult_0.5")
    assertThat(queryPerGroupResult.aggregationResults["sumResult"]).isWithin(0.5).of(4.5)
    assertThat(queryPerGroupResult.aggregationResults["quantilesResult_0.5"])
      .isWithin(0.5)
      .of(1.5)
  }

  companion object {
    // Shared SparkSession for the whole test class; created once via JUnit's @ClassRule.
    @JvmField @ClassRule val sparkSession = SparkSessionRule()
  }
}