You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
@@ -15,49 +18,61 @@ class Nanoset(torch.utils.data.Dataset):
15
18
The Nanoset dataset
16
19
17
20
Args:
18
-
dataset_paths (List[str]): List of paths to tokenized datasets
21
+
dataset_folders (List[str]): List of folders with tokenized datasets
19
22
dataset_weights (List[float]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
20
23
sequence_length (int): Sequence length of the built samples
21
-
token_dtype (Union[np.uint16, np.int32]): dtype of the tokens stored in the processed dataset files. np.uint16 for vocab sizes < 65535, np.int32 otherwise
24
+
token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise
22
25
train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size
23
26
"""
24
27
25
28
def __init__(
26
29
self,
27
-
dataset_paths: List[str],
30
+
dataset_folders: List[str],
28
31
dataset_weights: Union[List[float], None],
29
32
sequence_length: int,
30
-
token_dtype: Union[np.uint16, np.int32],
33
+
token_size: int,
31
34
train_split_num_samples: int,
32
35
random_seed: int=1234,
33
36
) -> None:
34
37
38
+
# Assertions
39
+
if isinstance(dataset_folders, str):
40
+
warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
f"> Total number of samples from the {self.dataset_paths[index].rsplit('/', 1)[-1]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
140
+
f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
0 commit comments