# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """TPU datasets tests.""" |

import os

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.tpu import datasets
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

_NUM_FILES = 10
_NUM_ENTRIES = 20


class DatasetsTest(test.TestCase):

  def setUp(self):
    super(DatasetsTest, self).setUp()
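    # Start two in-process gRPC servers: one for the 'coordinator' job on
    # which StreamingFilesDataset reads files by default, and one for the
    # 'worker' job that consumes the records.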
    self._coord = server_lib.Server.create_local_server()
    self._worker = server_lib.Server.create_local_server()

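    # Register both servers in a ClusterDef under the job names the dataset
    # expects, stripping the 'grpc://' scheme from each server target.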
    self._cluster_def = cluster_pb2.ClusterDef()
    worker_job = self._cluster_def.job.add()
    worker_job.name = 'worker'
    worker_job.tasks[0] = self._worker.target[len('grpc://'):]
    coord_job = self._cluster_def.job.add()
    coord_job.name = 'coordinator'
    coord_job.tasks[0] = self._coord.target[len('grpc://'):]

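    # Every test runs its ops through a session on the worker whose config
    # carries the cluster definition, so devices on either job resolve.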
    session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)

    self._sess = session.Session(self._worker.target, config=session_config)
    self._worker_device = '/job:' + worker_job.name

  def testTextLineDataset(self):
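    """Streams newline-delimited text files matched by a glob ('text')."""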
    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i)
      contents = []
      for j in range(_NUM_ENTRIES):
        contents.append(compat.as_bytes('%d: %d' % (i, j)))
      with open(filename, 'wb') as f:
        f.write(b'\n'.join(contents))
      all_contents.extend(contents)

    dataset = datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

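    # The streaming dataset does not stop after a single pass over the files,
    # so read more records than were written and compare them as sets.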
    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testTFRecordDataset(self):
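    """Streams TFRecord files matched by a glob ('tfrecord')."""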
    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
      writer = python_io.TFRecordWriter(filename)
      for j in range(_NUM_ENTRIES):
        record = compat.as_bytes('Record %d of file %d' % (j, i))
        writer.write(record)
        all_contents.append(record)
      writer.close()

    dataset = datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testTFRecordDatasetFromDataset(self):
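    """Streams TFRecord files named by a Dataset of filenames."""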
    filenames = []
    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
      filenames.append(filename)
      writer = python_io.TFRecordWriter(filename)
      for j in range(_NUM_ENTRIES):
        record = compat.as_bytes('Record %d of file %d' % (j, i))
        writer.write(record)
        all_contents.append(record)
      writer.close()

    filenames = dataset_ops.Dataset.from_tensor_slices(filenames)

    dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testArbitraryReaderFunc(self):
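    """Streams fixed-length record files through a custom reader function."""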

    def MakeRecord(i, j):
      return compat.as_bytes('%04d-%04d' % (i, j))

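    # The zero-padded format gives every record the same byte length, which
    # FixedLengthRecordDataset requires.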
    record_bytes = len(MakeRecord(10, 200))

    all_contents = []
    for i in range(_NUM_FILES):
      filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
      with open(filename, 'wb') as f:
        for j in range(_NUM_ENTRIES):
          record = MakeRecord(i, j)
          f.write(record)
          all_contents.append(record)

    def FixedLengthFile(filename):
      return readers.FixedLengthRecordDataset(filename, record_bytes)

    dataset = datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), 'fixed_length*'),
        filetype=FixedLengthFile)

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = []
    for _ in range(4 * len(all_contents)):
      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

    self.assertEqual(set(all_contents), set(retrieved_values))

  def testArbitraryReaderFuncFromDatasetGenerator(self):
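    """Streams elements from a generator-backed reader function."""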

    def my_generator():
      yield (1, [1] * 10)

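    # The reader function receives one element of the source dataset (ignored
    # here) and must return a Dataset; this one wraps a Python generator.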
    def gen_dataset(dummy):
      return dataset_ops.Dataset.from_generator(
          my_generator, (dtypes.int64, dtypes.int64),
          (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))

    dataset = datasets.StreamingFilesDataset(
        dataset_ops.Dataset.range(10), filetype=gen_dataset)

    with ops.device(self._worker_device):
      iterator = dataset_ops.make_initializable_iterator(dataset)
    self._sess.run(iterator.initializer)
    get_next = iterator.get_next()

    retrieved_values = self._sess.run(get_next)

    self.assertIsInstance(retrieved_values, (list, tuple))
    self.assertEqual(len(retrieved_values), 2)
    self.assertEqual(retrieved_values[0], 1)
    self.assertItemsEqual(retrieved_values[1], [1] * 10)

  def testUnexpectedFiletypeString(self):
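    """An unknown filetype string raises ValueError."""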
    with self.assertRaises(ValueError):
      datasets.StreamingFilesDataset(
          os.path.join(self.get_temp_dir(), '*'), filetype='foo')

  def testUnexpectedFiletypeType(self):
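    """A filetype that is neither a string nor callable raises ValueError."""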
    with self.assertRaises(ValueError):
      datasets.StreamingFilesDataset(
          os.path.join(self.get_temp_dir(), '*'), filetype=3)

  def testUnexpectedFilesType(self):
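    """Files that are neither a string nor a Dataset raise ValueError."""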
    with self.assertRaises(ValueError):
      datasets.StreamingFilesDataset(123, filetype='tfrecord')


if __name__ == '__main__':
  test.main()