diff --git a/python/ray/data/datasource/parquet_datasink.py b/python/ray/data/datasource/parquet_datasink.py
index 4056b04660f0..a8e085e5e0f3 100644
--- a/python/ray/data/datasource/parquet_datasink.py
+++ b/python/ray/data/datasource/parquet_datasink.py
@@ -6,6 +6,7 @@
 from ray.data._internal.util import call_with_retry
 from ray.data.block import Block, BlockAccessor
 from ray.data.context import DataContext
+from ray.data.datasource.file_based_datasource import _resolve_kwargs
 from ray.data.datasource.block_path_provider import BlockWritePathProvider
 from ray.data.datasource.file_datasink import _FileDatasink
 from ray.data.datasource.filename_provider import FilenameProvider
@@ -68,11 +69,14 @@ def write(
             blocks[0], ctx.task_idx, 0
         )
         write_path = posixpath.join(self.path, filename)
+        write_kwargs = _resolve_kwargs(
+            self.arrow_parquet_args_fn, **self.arrow_parquet_args
+        )
 
         def write_blocks_to_path():
             with self.open_output_stream(write_path) as file:
                 schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
-                with pq.ParquetWriter(file, schema) as writer:
+                with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
                     for block in blocks:
                         table = BlockAccessor.for_block(block).to_arrow()
                         writer.write_table(table)
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py
index 8020896365ed..4349c2560569 100644
--- a/python/ray/data/tests/test_parquet.py
+++ b/python/ray/data/tests/test_parquet.py
@@ -40,6 +40,19 @@ def check_num_computed(ds, streaming_expected) -> None:
     assert ds._plan.execute()._num_computed() == streaming_expected
 
 
+def test_write_parquet_supports_gzip(ray_start_regular_shared, tmp_path):
+    ray.data.range(1).write_parquet(tmp_path, compression="gzip")
+
+    # Test that all written files are gzip compressed.
+    for filename in os.listdir(tmp_path):
+        file_metadata = pq.ParquetFile(tmp_path / filename).metadata
+        compression = file_metadata.row_group(0).column(0).compression
+        assert compression == "GZIP", compression
+
+    # Test that you can read the written files.
+    assert pq.read_table(tmp_path).to_pydict() == {"id": [0]}
+
+
 def test_include_paths(ray_start_regular_shared, tmp_path):
     path = os.path.join(tmp_path, "test.txt")
     table = pa.Table.from_pydict({"animals": ["cat", "dog"]})
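
For reference, a minimal end-to-end usage sketch of what this change enables (not part of the patch; the output directory name is hypothetical). Extra keyword arguments to Dataset.write_parquet, such as compression, are collected into arrow_parquet_args and, with this change, resolved via _resolve_kwargs and forwarded to pyarrow.parquet.ParquetWriter, which previously was constructed without them:

import os

import pyarrow.parquet as pq
import ray

# Hypothetical output directory, used only for illustration.
out_dir = "/tmp/ray_parquet_gzip_demo"

# `compression` is collected into arrow_parquet_args; with this patch it
# reaches pq.ParquetWriter in the datasink's write path.
ray.data.range(10).write_parquet(out_dir, compression="gzip")

# Every written file should report the GZIP codec in its column metadata,
# mirroring the assertion in test_write_parquet_supports_gzip above.
for filename in os.listdir(out_dir):
    metadata = pq.ParquetFile(os.path.join(out_dir, filename)).metadata
    assert metadata.row_group(0).column(0).compression == "GZIP"

# The files remain readable through pyarrow.
print(pq.read_table(out_dir).to_pydict())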