-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
79 lines (62 loc) · 3.05 KB
/
data.py
File metadata and controls
79 lines (62 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from torchvision import transforms
from torch.utils.data import dataset, dataloader
from torchvision.datasets.folder import default_loader
from utils.RandomErasing import RandomErasing
from utils.RandomSampler import RandomSampler
from opt import opt
import os
import re
class Data:
    """Build the train / gallery / query datasets and their loaders.

    All settings are read from the module-level ``opt`` object:
    ``data_path`` (image root used for every split), ``batchid`` and
    ``batchimage`` (forwarded to ``RandomSampler``; their product is the
    training batch size), ``batchtest`` (batch size of the test and query
    loaders), ``mode`` and ``query_image``.

    Args:
        num_workers: number of worker processes used by every
            ``DataLoader`` (default 8, the previously hard-coded value).
            Pass 0 to load in the main process, e.g. where worker
            processes are unavailable or hard to debug.

    Attributes:
        trainset, testset, queryset: ``Market1501`` datasets for the splits.
        train_loader, test_loader, query_loader: their ``DataLoader``s.
        query_image: the transformed image at ``opt.query_image``; only set
            when ``opt.mode == 'vis'``.
    """

    def __init__(self, num_workers=8):
        # Per-channel normalisation statistics (the widely used ImageNet
        # values), defined once so the train and eval pipelines stay in sync.
        norm_mean = [0.485, 0.456, 0.406]
        norm_std = [0.229, 0.224, 0.225]

        # NOTE: no Resize step - images reach the model at their stored size
        # (a Resize((384, 128), interpolation=3) was previously commented out
        # in both pipelines).
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
            # Runs after Normalize, so it operates on the normalised tensor.
            RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0]),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])

        self.trainset = Market1501(train_transform, 'train', opt.data_path)
        self.testset = Market1501(test_transform, 'test', opt.data_path)
        self.queryset = Market1501(test_transform, 'query', opt.data_path)

        self.train_loader = dataloader.DataLoader(
            self.trainset,
            # The sampler supplies the visiting order. batch_size equals
            # batchid * batchimage so batches line up with the sampler's
            # grouping (presumably batchid identities x batchimage images -
            # confirm against RandomSampler).
            sampler=RandomSampler(self.trainset, batch_id=opt.batchid,
                                  batch_image=opt.batchimage),
            batch_size=opt.batchid * opt.batchimage,
            num_workers=num_workers,
            pin_memory=True,
        )
        self.test_loader = dataloader.DataLoader(
            self.testset, batch_size=opt.batchtest,
            num_workers=num_workers, pin_memory=True)
        self.query_loader = dataloader.DataLoader(
            self.queryset, batch_size=opt.batchtest,
            num_workers=num_workers, pin_memory=True)

        if opt.mode == 'vis':
            # Single query image, preprocessed like the evaluation sets.
            self.query_image = test_transform(default_loader(opt.query_image))
class Market1501(dataset.Dataset):
    """Unlabelled image dataset over every picture found under ``data_path``.

    Args:
        transform: callable applied to each loaded image, or ``None``.
        dtype: split name ('train', 'test' or 'query'). Currently ignored:
            every split lists the whole of ``data_path``.
        data_path: root directory searched recursively for images.
    """

    def __init__(self, transform, dtype, data_path):
        self.transform = transform
        self.loader = default_loader
        self.data_path = data_path
        # NOTE(review): ``dtype`` does not select a sub-folder, so the train,
        # test and query sets all contain the same images. Earlier
        # commented-out code appended 'bounding_box_train' / 'query' (and
        # 'bounding_box_train' for 'test' too - likely meant
        # 'bounding_box_test'); restore per-split folders if that is intended.
        self.imgs = self.list_pictures(self.data_path)

    def __getitem__(self, index):
        """Return ``(image, 0)`` for the image at ``index``.

        The target is a constant 0 placeholder; no label is derived from
        the file name.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, 0

    def __len__(self):
        """Return the number of listed image files."""
        return len(self.imgs)

    @staticmethod
    def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
        """Return the sorted paths of all image files under ``directory``.

        The search is recursive. A file qualifies only when its whole name
        is word characters, one dot, and one of the ``|``-separated
        extensions in ``ext`` (case-sensitive), e.g.
        ``0002_c1s1_000451_03.jpg``. Names such as ``-1_c1s1.jpg``,
        ``a.b.jpg`` or ``a.jpg.bak`` are skipped.

        Raises:
            FileNotFoundError: if ``directory`` is not an existing directory.
        """
        if not os.path.isdir(directory):
            # Explicit raise instead of ``assert``: under ``python -O`` the
            # assert vanished and a bad path silently produced an empty set.
            raise FileNotFoundError('dataset is not exists!{}'.format(directory))
        # fullmatch (not match) so the extension must end the name; with a
        # start-anchored match 'img.jpg.bak' was accepted and failed to load.
        pattern = re.compile(r'\w+\.(?:' + ext + ')')
        return sorted(os.path.join(root, name)
                      for root, _, files in os.walk(directory)
                      for name in files
                      if pattern.fullmatch(name))