blob: 30686e9d683b2b1d5198a0b8547e6f1926ed643b [file] [log] [blame]
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import attrs
import tree
from dp_accounting import dp_event
def assert_not_contains_attrs(structure):
  """Fails if any node of `structure` is an `attrs`-decorated instance.

  Args:
    structure: A (possibly nested) structure walked with `tree.traverse`.

  Raises:
    AssertionError: If an instance of an `attrs`-decorated class is found.
  """

  def _check_node(node):
    # A `None` result lets `tree.traverse` keep descending into `node`.
    if not attrs.has(type(node)):
      return None
    raise AssertionError(
        'Expected structure to not contain `attrs` decorated classes, '
        f'found {node}.'
    )

  tree.traverse(_check_node, structure)
def assert_not_contains_named_tuples(structure):
  """Fails if any node of `structure` is a `dp_event.DpEventNamedTuple`.

  Args:
    structure: A (possibly nested) structure walked with `tree.traverse`.

  Raises:
    AssertionError: If a `dp_event.DpEventNamedTuple` instance is found.
  """

  def _check_node(node):
    # A `None` result lets `tree.traverse` keep descending into `node`.
    if not isinstance(node, dp_event.DpEventNamedTuple):
      return None
    raise AssertionError(
        'Expected structure to not contain `dp_event.DpEventNamedTuple`s, '
        f'found {node}.'
    )

  tree.traverse(_check_node, structure)
class DpEventTest(parameterized.TestCase):
  """Round-trip tests between `DpEvent` objects and their named-tuple form."""

  @parameterized.named_parameters(
      dict(testcase_name='base_class', event=dp_event.DpEvent()),
      dict(testcase_name='no_op', event=dp_event.NoOpDpEvent()),
      dict(testcase_name='non_private', event=dp_event.NonPrivateDpEvent()),
      dict(testcase_name='unsupported', event=dp_event.UnsupportedDpEvent()),
      dict(testcase_name='gaussian', event=dp_event.GaussianDpEvent(1.0)),
      dict(testcase_name='laplace', event=dp_event.LaplaceDpEvent(1.0)),
      dict(
          testcase_name='self_composed',
          event=dp_event.SelfComposedDpEvent(
              dp_event.GaussianDpEvent(1.0), 10
          ),
      ),
      dict(
          testcase_name='composed',
          event=dp_event.ComposedDpEvent(
              [dp_event.GaussianDpEvent(1.0), dp_event.LaplaceDpEvent(1.0)]
          ),
      ),
      dict(
          testcase_name='poisson',
          event=dp_event.PoissonSampledDpEvent(
              0.1, dp_event.GaussianDpEvent(1.0)
          ),
      ),
      dict(
          testcase_name='sampled_with_replacement',
          event=dp_event.SampledWithReplacementDpEvent(
              1000, 10, dp_event.GaussianDpEvent(1.0)
          ),
      ),
      dict(
          testcase_name='sampled_without_replacement',
          event=dp_event.SampledWithoutReplacementDpEvent(
              1000, 10, dp_event.GaussianDpEvent(1.0)
          ),
      ),
      dict(
          testcase_name='tree_int',
          event=dp_event.SingleEpochTreeAggregationDpEvent(1.0, 5),
      ),
      dict(
          testcase_name='tree_list',
          event=dp_event.SingleEpochTreeAggregationDpEvent(1.0, [5, 10]),
      ),
      dict(
          testcase_name='repeat_and_select',
          event=dp_event.RepeatAndSelectDpEvent(
              dp_event.GaussianDpEvent(1.0), 30.0, 1.0
          ),
      ),
      dict(
          testcase_name='complex',
          event=dp_event.ComposedDpEvent([
              dp_event.SingleEpochTreeAggregationDpEvent(1.0, 5),
              dp_event.PoissonSampledDpEvent(
                  0.1, dp_event.LaplaceDpEvent(1.0)
              ),
              dp_event.SelfComposedDpEvent(
                  dp_event.SampledWithReplacementDpEvent(
                      1000, 10, dp_event.GaussianDpEvent(1.0)
                  ),
                  50,
              ),
          ]),
      ),
      dict(
          testcase_name='mixture_gaussian',
          event=dp_event.MixtureOfGaussiansDpEvent(
              1.0, [0, 1, 2], [0.25, 0.5, 0.25]
          ),
      ),
  )
  def test_to_from_named_tuple(self, event):
    """Checks that `to_named_tuple` and `from_named_tuple` round-trip `event`."""
    # Forward direction: the serialized form must be a tuple subtype, must be a
    # `DpEventNamedTuple`, and must not contain any `attrs` objects.
    as_tuple = event.to_named_tuple()
    self.assertIsInstance(as_tuple, tuple)
    self.assertIsInstance(as_tuple, dp_event.DpEventNamedTuple)
    assert_not_contains_attrs(as_tuple)

    # Reverse direction: rebuilding must leave no named tuples behind and must
    # compare equal to the original event.
    round_tripped = dp_event.DpEvent.from_named_tuple(as_tuple)
    assert_not_contains_named_tuples(round_tripped)
    self.assertEqual(event, round_tripped)
if __name__ == '__main__':
  # Run this module's test cases through absl's test runner.
  absltest.main()