@@ -13,7 +13,7 @@ class PetastormDataModule(pl.LightningDataModule):
     def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int = 1, has_val: bool = True,
                  train_batch_size: int = 32, val_batch_size: int = 32, shuffle_size: int = 1000,
                  num_reader_epochs=None, reader_pool_type: str = "process",
-                 reader_worker_count: int = 2, transform_spec=None, inmemory_cache_all=False,
+                 reader_worker_count: int = 2, transformation=None, inmemory_cache_all=False,
                  cur_shard: int = 0, shard_count: int = 1, schema_fields=None, storage_options=None,
                  steps_per_epoch_train: int = 1, steps_per_epoch_val: int = 1, verbose=True,
                  debug_data_loader: bool = False, train_async_data_loader_queue_size: int = None,
@@ -29,7 +29,7 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int = 1, has_va
         self.num_reader_epochs = num_reader_epochs
         self.reader_pool_type = reader_pool_type
         self.reader_worker_count = reader_worker_count
-        self.transform_spec = transform_spec
+        self.transformation = transformation
         self.inmemory_cache_all = inmemory_cache_all
         self.cur_shard = cur_shard
         self.shard_count = shard_count
@@ -49,13 +49,15 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int = 1, has_va
     def setup(self, stage=None):
         # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
-            transform_spec = TransformSpec(self.transform_spec) if self.transform_spec else None
+            transform_spec = TransformSpec(self.transformation) if self.transformation else None
             # In general, make_batch_reader is faster than make_reader for reading the dataset.
             # However, we found out that make_reader performs data transformations much faster than
             # make_batch_reader with parallel worker processes. Therefore, the default reader
             # we choose is make_batch_reader unless there are data transformations.
+            reader_factory_kwargs = dict()
             if transform_spec:
                 reader_factory = make_reader
+                reader_factory_kwargs['pyarrow_serialize'] = True
             else:
                 reader_factory = make_batch_reader
 
@@ -64,15 +66,19 @@ def setup(self, stage=None):
                                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                                schema_fields=self.schema_fields,
                                                storage_options=self.storage_options,
+                                               transform_spec=transform_spec,
                                                # Don't shuffle row groups without shuffling.
-                                               shuffle_row_groups=True if self.shuffle_size > 0 else False)
+                                               shuffle_row_groups=True if self.shuffle_size > 0 else False,
+                                               **reader_factory_kwargs)
             if self.has_val:
                 self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs,
                                                  cur_shard=self.cur_shard, shard_count=self.shard_count,
                                                  hdfs_driver=PETASTORM_HDFS_DRIVER,
                                                  schema_fields=self.schema_fields,
                                                  storage_options=self.storage_options,
-                                                 shuffle_row_groups=False)
+                                                 transform_spec=transform_spec,
+                                                 shuffle_row_groups=False,
+                                                 **reader_factory_kwargs)
 
     def teardown(self, stage=None):
         if stage == "fit" or stage is None:
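
For context, here is a caller-side sketch of the renamed argument. Everything below is illustrative and not part of this diff; the dataset URLs and the transform body are hypothetical:

    # Illustrative only: a per-row transform passed via the renamed `transformation`
    # argument of the PetastormDataModule patched above.
    def normalize_row(row):
        # Hypothetical transform; runs inside Petastorm reader workers once
        # setup() wraps it in TransformSpec.
        row['features'] = row['features'] / 255.0
        return row

    datamodule = PetastormDataModule(
        train_dir='file:///tmp/petastorm/train',  # hypothetical dataset URL
        val_dir='file:///tmp/petastorm/val',      # hypothetical dataset URL
        transformation=normalize_row,             # None keeps the make_batch_reader path
    )
    datamodule.setup(stage='fit')  # picks make_reader and sets pyarrow_serialize=True

Collecting the extra option in reader_factory_kwargs keeps pyarrow_serialize off the make_batch_reader call, presumably because only make_reader accepts that flag, while letting both factories share the same call sites.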