Skip to content

Commit

Permalink
set data_workers 0; pin_memory
Browse files Browse the repository at this point in the history
  • Loading branch information
donghyeonk committed Apr 24, 2023
1 parent 3ff0332 commit 4e28f74
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def process_data(self, path, update_dict=False):

return total_data

def get_dataloader(self, batch_size=None, shuffle=True):
def get_dataloader(self, batch_size=None, shuffle=True, use_cuda=False):
if batch_size is None:
batch_size = self.config.batch_size

Expand All @@ -446,7 +446,7 @@ def get_dataloader(self, batch_size=None, shuffle=True):
sampler=train_sampler,
num_workers=self.config.data_workers,
collate_fn=self.batchify,
pin_memory=True
pin_memory=use_cuda,
)
else:
train_loader = None
Expand All @@ -462,7 +462,7 @@ def get_dataloader(self, batch_size=None, shuffle=True):
sampler=valid_sampler,
num_workers=self.config.data_workers,
collate_fn=self.batchify,
pin_memory=True
pin_memory=use_cuda,
)
else:
valid_loader = None
Expand All @@ -477,7 +477,7 @@ def get_dataloader(self, batch_size=None, shuffle=True):
sampler=test_sampler,
num_workers=self.config.data_workers,
collate_fn=self.batchify,
pin_memory=True
pin_memory=use_cuda,
)

return train_loader, valid_loader, test_loader
Expand Down Expand Up @@ -665,7 +665,7 @@ def __init__(self):
self.dur_size = 0
self.class_div = 0
self.slot_size = 0
self.data_workers = 4
self.data_workers = 0
self.save_dataset = False
self.sm_day_num = 7
self.sm_slot_num = 24
Expand Down

0 comments on commit 4e28f74

Please sign in to comment.