Skip to content

Commit 7707267

Browse files
authored
Spark/Lightning: add missing transform_spec for Petastorm datamodule (horovod#3543)
Fixes issue #3540. Signed-off-by: Chongxiao Cao <[email protected]>
1 parent 464c82e commit 7707267

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
3434

3535
- Fallback to NCCL shared lib if static one is not found. ([#3500](https://github.com/horovod/horovod/pull/3500))
3636

37+
- Spark/Lightning: add missing transform_spec for Petastorm datamodule. ([#3543](https://github.com/horovod/horovod/pull/3543))
38+
3739
## [v0.24.3] - 2022-04-21
3840

3941
### Fixed

horovod/spark/lightning/datamodule.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class PetastormDataModule(pl.LightningDataModule):
1313
def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_val: bool=True,
1414
train_batch_size: int=32, val_batch_size: int=32, shuffle_size: int=1000,
1515
num_reader_epochs=None, reader_pool_type: str="process",
16-
reader_worker_count: int=2, transform_spec=None, inmemory_cache_all=False,
16+
reader_worker_count: int=2, transformation=None, inmemory_cache_all=False,
1717
cur_shard: int=0, shard_count: int=1, schema_fields=None, storage_options=None,
1818
steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True,
1919
debug_data_loader: bool=False, train_async_data_loader_queue_size: int=None,
@@ -29,7 +29,7 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
2929
self.num_reader_epochs = num_reader_epochs
3030
self.reader_pool_type = reader_pool_type
3131
self.reader_worker_count = reader_worker_count
32-
self.transform_spec = transform_spec
32+
self.transformation = transformation
3333
self.inmemory_cache_all = inmemory_cache_all
3434
self.cur_shard = cur_shard
3535
self.shard_count = shard_count
@@ -49,13 +49,15 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
4949
def setup(self, stage=None):
5050
# Assign train/val datasets for use in dataloaders
5151
if stage == 'fit' or stage is None:
52-
transform_spec = TransformSpec(self.transform_spec) if self.transform_spec else None
52+
transform_spec = TransformSpec(self.transformation) if self.transformation else None
5353
# In general, make_batch_reader is faster than make_reader for reading the dataset.
5454
# However, we found out that make_reader performs data transformations much faster than
5555
# make_batch_reader with parallel worker processes. Therefore, the default reader
5656
# we choose is make_batch_reader unless there are data transformations.
57+
reader_factory_kwargs = dict()
5758
if transform_spec:
5859
reader_factory = make_reader
60+
reader_factory_kwargs['pyarrow_serialize'] = True
5961
else:
6062
reader_factory = make_batch_reader
6163

@@ -64,15 +66,19 @@ def setup(self, stage=None):
6466
hdfs_driver=PETASTORM_HDFS_DRIVER,
6567
schema_fields=self.schema_fields,
6668
storage_options=self.storage_options,
69+
transform_spec=transform_spec,
6770
# Don't shuffle row groups without shuffling.
68-
shuffle_row_groups=True if self.shuffle_size > 0 else False)
71+
shuffle_row_groups=True if self.shuffle_size > 0 else False,
72+
**reader_factory_kwargs)
6973
if self.has_val:
7074
self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs,
7175
cur_shard=self.cur_shard, shard_count=self.shard_count,
7276
hdfs_driver=PETASTORM_HDFS_DRIVER,
7377
schema_fields=self.schema_fields,
7478
storage_options=self.storage_options,
75-
shuffle_row_groups=False)
79+
transform_spec=transform_spec,
80+
shuffle_row_groups=False,
81+
**reader_factory_kwargs)
7682

7783
def teardown(self, stage=None):
7884
if stage == "fit" or stage is None:

horovod/spark/lightning/remote.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
266266
'num_reader_epochs': loader_num_epochs,
267267
'reader_pool_type': reader_pool_type,
268268
'reader_worker_count': train_reader_worker_count,
269-
'transform_spec': transformation,
269+
'transformation': transformation,
270270
'inmemory_cache_all': inmemory_cache_all,
271271
'cur_shard': hvd.rank(),
272272
'shard_count': hvd.size(),

horovod/spark/torch/remote.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def save_checkpoint():
264264
transform_spec=transform_spec,
265265
storage_options=storage_options,
266266
# Don't shuffle row groups without shuffling.
267-
shuffle_row_groups=True if shuffle_buffer_size > 0 else False
267+
shuffle_row_groups=True if shuffle_buffer_size > 0 else False,
268268
**reader_factory_kwargs) as train_reader:
269269
with reader_factory(remote_store.val_data_path,
270270
num_epochs=None,

0 commit comments

Comments
 (0)