-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathherbarium_19.py
164 lines (110 loc) · 5.83 KB
/
herbarium_19.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import torchvision
import numpy as np
from copy import deepcopy
from data.data_utils import subsample_instances
from config import herbarium_dataroot
class HerbariumDataset19(torchvision.datasets.ImageFolder):
    """ImageFolder over the Herbarium-2019 data that additionally returns a
    stable, unique index for every sample.

    The unique indices (``uq_idxs``) let downstream code track individual
    instances across the subsampling helpers in this module, which mutate
    ``samples``/``targets`` in place.
    """

    def __init__(self, *args, **kwargs):
        # Process metadata json for training images into a DataFrame
        super().__init__(*args, **kwargs)
        # One stable id per sample, assigned once at construction time.
        self.uq_idxs = np.array(range(len(self)))

    def __getitem__(self, idx):
        # (image, label) from ImageFolder, plus this sample's unique id.
        img, label = super().__getitem__(idx)
        uq_idx = self.uq_idxs[idx]
        return img, label, uq_idx
def subsample_dataset(dataset, idxs):
    """Restrict ``dataset`` in place to the samples at positions ``idxs``.

    Args:
        dataset: an ImageFolder-like object with parallel ``samples``,
            ``targets`` lists and a ``uq_idxs`` numpy array.
        idxs: positional indices (list or ndarray) of the samples to keep.
            Selection follows the boolean mask, so duplicates are collapsed
            and the original sample order is preserved.

    Returns:
        The same ``dataset`` object, mutated in place.
    """
    mask = np.zeros(len(dataset), dtype=bool)
    mask[idxs] = True
    # Filter the parallel lists directly. The original implementation went
    # through np.array(dataset.samples), which stringified every entry
    # (paths AND integer labels) and then needed int() re-casts afterwards.
    dataset.samples = [[path, int(target)]
                       for keep, (path, target) in zip(mask, dataset.samples) if keep]
    dataset.targets = [int(t) for keep, t in zip(mask, dataset.targets) if keep]
    dataset.uq_idxs = dataset.uq_idxs[mask]
    return dataset
def subsample_classes(dataset, include_classes=range(250)):
    """Keep only the samples whose label is in ``include_classes`` and attach
    a ``target_transform`` remapping those labels to 0..K-1 (in the order the
    classes appear in ``include_classes``). Mutates and returns ``dataset``."""
    keep_idxs = [pos for pos, lbl in enumerate(dataset.targets) if lbl in include_classes]
    # Old label -> contiguous new label, ordered by include_classes.
    remap = {old: new for new, old in enumerate(include_classes)}
    dataset = subsample_dataset(dataset, keep_idxs)
    dataset.target_transform = lambda lbl: remap[lbl]
    return dataset
def get_train_val_indices(train_dataset, val_instances_per_class=5):
    """Split a labelled dataset's indices into train/val, class-balanced on val.

    For each class, ``val_instances_per_class`` indices are drawn at random
    (uses the global numpy RNG; seed it for reproducibility) and the rest go
    to train.

    Args:
        train_dataset: object with a ``targets`` sequence of int labels.
        val_instances_per_class: samples per class held out for validation.

    Returns:
        (train_idxs, val_idxs): two lists of positional indices.
    """
    # Hoisted out of the loop: the original rebuilt this array once per class.
    targets = np.array(train_dataset.targets)
    train_classes = list(set(train_dataset.targets))

    # Get train/test indices
    train_idxs = []
    val_idxs = []
    for cls in train_classes:
        cls_idxs = np.where(targets == cls)[0]

        # Have a balanced test set
        v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,))
        # Set membership is O(1); the original scanned a numpy array per element.
        v_set = set(v_.tolist())
        t_ = [x for x in cls_idxs if x not in v_set]

        train_idxs.extend(t_)
        val_idxs.extend(v_)

    return train_idxs, val_idxs
def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8,
                           seed=0, split_train_val=False):
    """Build the labelled/unlabelled/val/test splits for Herbarium-19.

    Args:
        train_transform: transform applied to the training images.
        test_transform: transform applied to val/test images.
        train_classes: classes treated as labelled.
        prop_train_labels: fraction of labelled-class instances kept labelled.
        seed: numpy RNG seed used for all subsampling below.
        split_train_val: if True, carve a balanced val split out of the
            labelled set; otherwise 'val' is None and 'train_labelled' is the
            full labelled set.

    Returns:
        dict with keys 'train_labelled', 'train_unlabelled', 'val', 'test'.
    """
    np.random.seed(seed)

    # Init entire training set
    train_dataset = HerbariumDataset19(transform=train_transform,
                                       root=os.path.join(herbarium_dataroot, 'small-train'))

    # Get labelled training set which has subsampled classes, then subsample some indices from that
    # TODO: Subsampling unlabelled set in uniform random fashion from training data, will contain many instances of dominant class
    train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes)
    subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
    train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)

    # Split into training and validation sets
    if split_train_val:
        train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled,
                                                     val_instances_per_class=5)
        train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
        val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
        # Validation images should not get training-time augmentation.
        val_dataset_labelled_split.transform = test_transform
    else:
        train_dataset_labelled_split, val_dataset_labelled_split = None, None

    # Get unlabelled data: everything in the full train set that did not end
    # up in the labelled subset, identified via the stable uq_idxs.
    unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs)
    train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices)))

    # Get test dataset
    test_dataset = HerbariumDataset19(transform=test_transform,
                                      root=os.path.join(herbarium_dataroot, 'small-validation'))

    # Transform dict: labelled classes first (0..|train_classes|-1), then the
    # remaining (unlabelled-only) classes, so test/unlabelled labels share one
    # contiguous mapping with the labelled set.
    unlabelled_classes = list(set(train_dataset.targets) - set(train_classes))
    target_xform_dict = {}
    for i, k in enumerate(list(train_classes) + unlabelled_classes):
        target_xform_dict[k] = i

    test_dataset.target_transform = lambda x: target_xform_dict[x]
    train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x]

    # Either split train into train and val or use test set as val
    train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
    val_dataset_labelled = val_dataset_labelled_split if split_train_val else None

    all_datasets = {
        'train_labelled': train_dataset_labelled,
        'train_unlabelled': train_dataset_unlabelled,
        'val': val_dataset_labelled,
        'test': test_dataset,
    }

    return all_datasets
if __name__ == '__main__':
    # Smoke test: build the splits with half of the 683 Herbarium-19 classes
    # labelled, then print summary statistics.
    np.random.seed(0)
    # Fixed: original had a stray trailing comma in range(683,).
    train_classes = np.random.choice(range(683), size=(int(683 / 2)), replace=False)

    x = get_herbarium_datasets(None, None, train_classes=train_classes,
                               prop_train_labels=0.5)

    # The unlabelled pool should still contain instances of every class.
    assert set(x['train_unlabelled'].targets) == set(range(683))

    print('Printing lens...')
    for k, v in x.items():
        if v is not None:
            print(f'{k}: {len(v)}')

    print('Printing labelled and unlabelled overlap...')
    print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
    print('Printing total instances in train...')
    print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
    print('Printing number of labelled classes...')
    print(len(set(x['train_labelled'].targets)))
    print('Printing total number of classes...')
    print(len(set(x['train_unlabelled'].targets)))

    print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
    # Fixed typo in the printed label: "Unabelled" -> "Unlabelled".
    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
    print(f'Len labelled set: {len(x["train_labelled"])}')
    print(f'Len unlabelled set: {len(x["train_unlabelled"])}')