diff --git a/tensorflow_datasets/core/naming.py b/tensorflow_datasets/core/naming.py
index 91f90341cee..fff8fcc14bb 100644
--- a/tensorflow_datasets/core/naming.py
+++ b/tensorflow_datasets/core/naming.py
@@ -666,7 +666,7 @@ def sharded_filepaths_pattern(
     `/path/dataset_name-split.fileformat@num_shards` or
     `/path/dataset_name-split@num_shards.fileformat` depending on the format.
     If `num_shards` is not given, then it returns
-    `/path/dataset_name-split.fileformat-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]`.
+    `/path/dataset_name-split.fileformat*`.
 
     Args:
       num_shards: optional specification of the number of shards.
@@ -681,7 +681,7 @@ def sharded_filepaths_pattern(
     elif use_at_notation:
       replacement = '@*'
     else:
-      replacement = '-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
+      replacement = '*'
     return _replace_shard_pattern(os.fspath(a_filepath), replacement)
 
   def glob_pattern(self, num_shards: int | None = None) -> str:
diff --git a/tensorflow_datasets/core/naming_test.py b/tensorflow_datasets/core/naming_test.py
index 64b0c790a2c..1859dbcdfbc 100644
--- a/tensorflow_datasets/core/naming_test.py
+++ b/tensorflow_datasets/core/naming_test.py
@@ -459,7 +459,7 @@ def test_sharded_file_template_shard_index():
   )
   assert (
       os.fspath(template.sharded_filepaths_pattern())
-      == '/my/path/data/mnist-train.tfrecord-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
+      == '/my/path/data/mnist-train.tfrecord*'
   )
   assert (
       os.fspath(template.sharded_filepaths_pattern(num_shards=100))
@@ -474,10 +474,7 @@ def test_glob_pattern():
       filetype_suffix='tfrecord',
       data_dir=epath.Path('/data'),
   )
-  assert (
-      '/data/ds-train.tfrecord-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
-      == template.glob_pattern()
-  )
+  assert '/data/ds-train.tfrecord*' == template.glob_pattern()
   assert '/data/ds-train.tfrecord-*-of-00042' == template.glob_pattern(
       num_shards=42
   )
diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py
index 45237d32e45..af5b76a0214 100644
--- a/tensorflow_datasets/core/writer.py
+++ b/tensorflow_datasets/core/writer.py
@@ -816,9 +816,8 @@ def finalize(self) -> tuple[list[int], int]:
     logging.info("Finalizing writer for %s", self._filename_template.split)
     # We don't know the number of shards, the length of each shard, nor the
     # total size, so we compute them here.
-    shards = self._filename_template.data_dir.glob(
-        self._filename_template.glob_pattern()
-    )
+    prefix = epath.Path(self._filename_template.filepath_prefix())
+    shards = self._filename_template.data_dir.glob(f"{prefix.name}*")
 
     def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
       length = self._file_adapter.num_examples(shard)
diff --git a/tensorflow_datasets/core/writer_test.py b/tensorflow_datasets/core/writer_test.py
index d742045df86..b9fc7f18340 100644
--- a/tensorflow_datasets/core/writer_test.py
+++ b/tensorflow_datasets/core/writer_test.py
@@ -592,47 +592,39 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):
     with tempfile.TemporaryDirectory() as tmp_dir:
       tmp_dir = epath.Path(tmp_dir)
-
-      def get_writer(split):
-        filename_template = naming.ShardedFileTemplate(
-            dataset_name='foo',
-            split=split,
-            filetype_suffix=file_format.file_suffix,
-            data_dir=tmp_dir,
-        )
-        return writer_lib.NoShuffleBeamWriter(
-            serializer=testing.DummySerializer('dummy specs'),
-            filename_template=filename_template,
-            file_format=file_format,
-        )
-
+      filename_template = naming.ShardedFileTemplate(
+          dataset_name='foo',
+          split='train',
+          filetype_suffix=file_format.file_suffix,
+          data_dir=tmp_dir,
+      )
+      writer = writer_lib.NoShuffleBeamWriter(
+          serializer=testing.DummySerializer('dummy specs'),
+          filename_template=filename_template,
+          file_format=file_format,
+      )
       to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
       # Here we need to disable type check as `beam.Create` is not capable of
       # inferring the type of the PCollection elements.
       options = beam.options.pipeline_options.PipelineOptions(
           pipeline_type_check=False
       )
-      writers = [get_writer(split) for split in ('train-b', 'train')]
-
-      for writer in writers:
-        with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
-
-          @beam.ptransform_fn
-          def _build_pcollection(pipeline, writer):
-            pcollection = pipeline | 'Start' >> beam.Create(to_write)
-            return writer.write_from_pcollection(pcollection)
-
-          _ = pipeline | 'test' >> _build_pcollection(writer)
-
+      with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
+
+        @beam.ptransform_fn
+        def _build_pcollection(pipeline):
+          pcollection = pipeline | 'Start' >> beam.Create(to_write)
+          return writer.write_from_pcollection(pcollection)
+
+        _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
+      shard_lengths, total_size = writer.finalize()
+      self.assertNotEmpty(shard_lengths)
+      self.assertEqual(sum(shard_lengths), 10)
+      self.assertGreater(total_size, 10)
       files = list(tmp_dir.iterdir())
-      self.assertGreaterEqual(len(files), 2)
+      self.assertGreaterEqual(len(files), 1)
       for f in files:
         self.assertIn(file_format.file_suffix, f.name)
-      for writer in writers:
-        shard_lengths, total_size = writer.finalize()
-        self.assertNotEmpty(shard_lengths)
-        self.assertEqual(sum(shard_lengths), 10)
-        self.assertGreater(total_size, 10)
 
 
 class CustomExampleWriter(writer_lib.ExampleWriter):
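Not part of the patch: a minimal sketch of why the relaxed `*` suffix used by sharded_filepaths_pattern() / glob_pattern() still selects the same shard files as the old fixed-width `[0-9]` pattern. The shard file names and the 'mnist-train' prefix below are made up for illustration, and fnmatch is used only to emulate the glob-style matching that data_dir.glob() performs.

# Illustration only (hypothetical file names), runnable as plain Python.
import fnmatch

shard_files = [
    'mnist-train.tfrecord-00000-of-00005',
    'mnist-train.tfrecord-00004-of-00005',
    'mnist-test.tfrecord-00000-of-00001',  # different split, must not match
]

old_pattern = (
    'mnist-train.tfrecord'
    '-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
)
new_pattern = 'mnist-train.tfrecord*'

# Both patterns keep only the two 'train' shards; the '*' form simply stops
# assuming the shard counter is exactly five digits wide.
assert fnmatch.filter(shard_files, old_pattern) == shard_files[:2]
assert fnmatch.filter(shard_files, new_pattern) == shard_files[:2]

The writer.py hunk applies the same idea when finalizing: rather than globbing with the full template pattern, it lists everything that starts with the file-name prefix (f"{prefix.name}*"), which is equivalent to the relaxed pattern shown above.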