| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for input pipeline modifications for distribution strategies.""" |
| |
| import os |
| |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.data.ops import readers |
| from tensorflow.python.data.util import structure |
| from tensorflow.python.distribute import input_ops |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.lib.io import python_io |
| from tensorflow.python.ops import gen_dataset_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.platform import test |
| from tensorflow.python.util import compat |
| |
| |
| class AutoShardDatasetTest(test.TestCase): |
| |
| def setUp(self): |
| super(AutoShardDatasetTest, self).setUp() |
| self._num_files = 10 |
| self._num_records = 4 |
| self._num_shards = 2 |
| self._shard_index = 0 |
| self._record_bytes = 10 |
| |
  def _getNext(self, dataset):
    # Returns a callable that produces the next element, bridging eager and
    # graph modes: in eager mode each call advances the iterator directly,
    # while in graph mode the same get_next op is returned and re-evaluated
    # on every call.
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next
| |
| def _record(self, r, f): |
| return compat.as_bytes("Record %d of file %d" % (r, f)) |
| |
| def _text_line(self, r, f): |
| return compat.as_bytes("Text line %d of file %d" % (r, f)) |
| |
| def _fixed_length_record(self, r, f): |
| return compat.as_bytes(str((r * f) % 10) * self._record_bytes) |
| |
| def _createTFRecordFiles(self): |
| filenames = [] |
| for i in range(self._num_files): |
| fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) |
| filenames.append(fn) |
      with python_io.TFRecordWriter(fn) as writer:
        for j in range(self._num_records):
          writer.write(self._record(j, i))
| return filenames |
| |
| def _createTextFiles(self): |
| filenames = [] |
| for i in range(self._num_files): |
| fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) |
| filenames.append(fn) |
| contents = [] |
      for j in range(self._num_records):
        contents.append(self._text_line(j, i))
        # Write a newline after every record except the last one in each
        # file; the first file keeps its trailing newline, so both
        # newline-terminated and unterminated files are exercised.
        if j + 1 != self._num_records or i == 0:
          contents.append(b"\r\n")
| contents = b"".join(contents) |
| |
| with open(fn, "wb") as f: |
| f.write(contents) |
| return filenames |
| |
| def _createFixedLengthRecordFiles(self): |
| filenames = [] |
| for i in range(self._num_files): |
| fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) |
| filenames.append(fn) |
| with open(fn, "wb") as f: |
| for j in range(self._num_records): |
| f.write(self._fixed_length_record(j, i)) |
| return filenames |
| |
  def _verifySimpleShardingOutput(self, dataset, record_fn):
    # Auto-sharding is expected to assign files to shards round-robin, so
    # this shard should see files shard_index, shard_index + num_shards, ...
    # in order, produce every record of each file, and then go out of range.
    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(record_fn(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())
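
  # With the setUp values, the check above is concrete: 10 files across
  # 2 shards put files 0, 2, 4, 6, 8 on shard 0; each file holds 4 records,
  # so a sharded dataset yields 20 elements and then raises OutOfRangeError.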
| |
| @test_util.run_in_graph_and_eager_modes |
| def testTFRecordDataset(self): |
| dataset = readers.TFRecordDataset(self._createTFRecordFiles()) |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| self._verifySimpleShardingOutput(dataset, self._record) |
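
  # A companion check for the second shard: a minimal sketch assuming
  # auto_shard_dataset assigns files round-robin for any shard index in
  # [0, num_shards), so shard 1 should see files 1, 3, 5, ...
  @test_util.run_in_graph_and_eager_modes
  def testTFRecordDatasetSecondShard(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(dataset, self._num_shards, 1)

    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(1, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(
              self._record(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())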
| |
| @test_util.run_in_graph_and_eager_modes |
| def testFlatMap(self): |
| dataset = dataset_ops.Dataset.from_tensor_slices( |
| self._createTFRecordFiles()) |
| dataset = dataset.flat_map(readers.TFRecordDataset) |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| self._verifySimpleShardingOutput(dataset, self._record) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testInterleave(self): |
| dataset = dataset_ops.Dataset.from_tensor_slices( |
| self._createTFRecordFiles()) |
| dataset = dataset.interleave( |
| readers.TFRecordDataset, cycle_length=4, block_length=self._num_records) |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
    # Since block_length equals the number of records per file, interleave
    # drains each file completely before moving to the next, so the output
    # still groups records by file, in file order.
| self._verifySimpleShardingOutput(dataset, self._record) |
| |
| @test_util.run_in_graph_and_eager_modes |
  def testListFiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = os.path.join(
        os.path.dirname(filenames[0]), "tf_record.*.txt")
| dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False) |
| dataset = dataset.flat_map(readers.TFRecordDataset) |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| next_element_fn = self._getNext(dataset) |
| actual, expected = [], [] |
| for f in range(self._shard_index, self._num_files, self._num_shards): |
| for r in range(self._num_records): |
| actual.append(self.evaluate(next_element_fn())) |
| expected.append(self._record(r, f)) |
| with self.assertRaises(errors.OutOfRangeError): |
| self.evaluate(next_element_fn()) |
| self.assertAllEqual(expected, actual) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testComplexPipeline(self): |
    # Set up a complex input pipeline.
| batch_size = 2 |
| num_epochs = 5 |
| dataset = dataset_ops.Dataset.from_tensor_slices( |
| self._createTFRecordFiles()) |
| dataset = dataset.shuffle(buffer_size=self._num_files) |
| dataset = dataset.flat_map(readers.TFRecordDataset) |
| dataset = dataset.prefetch(buffer_size=batch_size) |
| dataset = dataset.shuffle(2 * self._num_files * self._num_records) |
| dataset = dataset.repeat(num_epochs) |
| dataset = dataset.map(lambda x: x) |
| dataset = dataset.batch(batch_size) |
| dataset = dataset.prefetch(buffer_size=None) |
| |
| # Auto shard. |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| # Verify output. |
| next_element_fn = self._getNext(dataset) |
| actual = [] |
| num_iterations = (self._num_files * self._num_records * num_epochs) // ( |
| self._num_shards * batch_size) |
| for _ in range(num_iterations): |
| actual.extend(self.evaluate(next_element_fn())) |
| with self.assertRaises(errors.OutOfRangeError): |
| self.evaluate(next_element_fn()) |
| |
    expected = []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    # The shuffles make the element order nondeterministic, so compare the
    # two multisets rather than the two sequences.
    self.assertAllEqual(sorted(expected), sorted(actual))
| |
| @test_util.run_in_graph_and_eager_modes |
| def testZip(self): |
| dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) |
| dataset2 = readers.TextLineDataset(self._createTextFiles()) |
| |
| dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f)) |
| self._verifySimpleShardingOutput(dataset, record_fn) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcat(self): |
| dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) |
| dataset2 = readers.TextLineDataset(self._createTextFiles()) |
| |
| dataset = dataset1.concatenate(dataset2) |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| next_element_fn = self._getNext(dataset) |
| for f in range(self._shard_index, self._num_files, self._num_shards): |
| for r in range(self._num_records): |
| self.assertAllEqual( |
| self._record(r, f), self.evaluate(next_element_fn())) |
| for f in range(self._shard_index, self._num_files, self._num_shards): |
| for r in range(self._num_records): |
| self.assertAllEqual( |
| self._text_line(r, f), self.evaluate(next_element_fn())) |
| with self.assertRaises(errors.OutOfRangeError): |
| self.evaluate(next_element_fn()) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testTextLineReader(self): |
| dataset = readers.TextLineDataset(self._createTextFiles()) |
| |
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| self._verifySimpleShardingOutput(dataset, self._text_line) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testTextLineReaderWithFlatMap(self): |
    dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
    dataset = dataset.flat_map(readers.TextLineDataset)
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| self._verifySimpleShardingOutput(dataset, self._text_line) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testFixedLengthReaderWithFlatMap(self): |
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createFixedLengthRecordFiles())
    dataset = dataset.flat_map(
        lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
| dataset = input_ops.auto_shard_dataset( |
| dataset, self._num_shards, self._shard_index) |
| |
| self._verifySimpleShardingOutput(dataset, self._fixed_length_record) |
| |
| |
# A dataset whose construction creates two variant tensors (an intermediate
# prefetch variant and the final model variant). Used to check that dataset
# cloning handles transformations that emit more than one variant tensor.
| class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset): |
| |
  def __init__(self, input_dataset):
    # _input_dataset must be set before _flat_structure is read below, since
    # the flat structure is derived from the input dataset's element spec.
    self._input_dataset = input_dataset
    temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=1,
        **self._flat_structure)
| variant_tensor = gen_dataset_ops.model_dataset( |
| temp_variant_tensor, **self._flat_structure) |
| super(_TestDataset, self).__init__(input_dataset, variant_tensor) |
| |
| |
| class CloneDatasetTest(test.TestCase): |
| |
| def _assert_datasets_equal(self, ds1, ds2): |
    # First, assert that the structures are the same.
| self.assertTrue( |
| structure.are_compatible(ds1.element_spec, ds2.element_spec)) |
| |
| # Now create iterators on both and assert they produce the same values. |
| it1 = dataset_ops.make_initializable_iterator(ds1) |
| it2 = dataset_ops.make_initializable_iterator(ds2) |
| |
| get_next1 = it1.get_next() |
| get_next2 = it2.get_next() |
| |
| with self.cached_session(): |
| self.evaluate([it1.initializer, it2.initializer]) |
      # Compare every element until the iterators are exhausted.
      while True:
        try:
          val1, val2 = self.evaluate([get_next1, get_next2])
        except errors.OutOfRangeError:
          break
        self.assertAllEqual(val1, val2)
| |
| @test_util.run_deprecated_v1 |
| def testOnlySource(self): |
| ds = dataset_ops.Dataset.range(10) |
| cloned_ds = input_ops._clone_dataset(ds) |
| self._assert_datasets_equal(ds, cloned_ds) |
| |
| @test_util.run_deprecated_v1 |
| def testSimplePipeline(self): |
| ds = dataset_ops.Dataset.range(10).map(math_ops.square) |
| cloned_ds = input_ops._clone_dataset(ds) |
| self._assert_datasets_equal(ds, cloned_ds) |
| |
| @test_util.run_deprecated_v1 |
| def testConcat(self): |
| ds1 = dataset_ops.Dataset.range(10) |
| ds2 = dataset_ops.Dataset.range(10) |
| ds = ds1.concatenate(ds2) |
| cloned_ds = input_ops._clone_dataset(ds) |
| self._assert_datasets_equal(ds, cloned_ds) |
| |
| @test_util.run_deprecated_v1 |
| def testZip(self): |
| ds1 = dataset_ops.Dataset.range(10) |
| ds2 = dataset_ops.Dataset.range(10) |
| ds = dataset_ops.Dataset.zip((ds1, ds2)) |
| cloned_ds = input_ops._clone_dataset(ds) |
| self._assert_datasets_equal(ds, cloned_ds) |
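
  # A further sketch, assuming _clone_dataset rebuilds the entire upstream op
  # graph: cloning should also preserve a pipeline that ends in batching,
  # where each element is a vector rather than a scalar.
  @test_util.run_deprecated_v1
  def testBatchedPipeline(self):
    ds = dataset_ops.Dataset.range(10).batch(2)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)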
| |
| @test_util.run_deprecated_v1 |
| def testMultipleVariantTensors(self): |
| ds = dataset_ops.Dataset.range(10) |
| ds = _TestDataset(ds) |
| cloned_ds = input_ops._clone_dataset(ds) |
| self._assert_datasets_equal(ds, cloned_ds) |
| |
| |
| if __name__ == "__main__": |
| test.main() |