diff --git a/dataset.py b/dataset.py index 44f0429..79bcd60 100644 --- a/dataset.py +++ b/dataset.py @@ -431,7 +431,7 @@ def process_data(self, path, update_dict=False): return total_data - def get_dataloader(self, batch_size=None, shuffle=True): + def get_dataloader(self, batch_size=None, shuffle=True, use_cuda=False): if batch_size is None: batch_size = self.config.batch_size @@ -446,7 +446,7 @@ def get_dataloader(self, batch_size=None, shuffle=True): sampler=train_sampler, num_workers=self.config.data_workers, collate_fn=self.batchify, - pin_memory=True + pin_memory=use_cuda, ) else: train_loader = None @@ -462,7 +462,7 @@ def get_dataloader(self, batch_size=None, shuffle=True): sampler=valid_sampler, num_workers=self.config.data_workers, collate_fn=self.batchify, - pin_memory=True + pin_memory=use_cuda, ) else: valid_loader = None @@ -477,7 +477,7 @@ def get_dataloader(self, batch_size=None, shuffle=True): sampler=test_sampler, num_workers=self.config.data_workers, collate_fn=self.batchify, - pin_memory=True + pin_memory=use_cuda, ) return train_loader, valid_loader, test_loader @@ -665,7 +665,7 @@ def __init__(self): self.dur_size = 0 self.class_div = 0 self.slot_size = 0 - self.data_workers = 4 + self.data_workers = 0 self.save_dataset = False self.sm_day_num = 7 self.sm_slot_num = 24