1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import sys
6
+ import os
7
+ import os .path as osp
8
+ import glob
9
+ import re
10
+ import warnings
11
+
12
+ from torchreid .data .datasets import ImageDataset
13
+ from torchreid .utils import read_image
14
+ import cv2
15
+ import numpy as np
16
+
17
+ class Occluded_duke (ImageDataset ):
18
+
19
+ def __init__ (self , root = '' , ** kwargs ):
20
+ dataset_dir = 'ICME2018_Occluded-Person-Reidentification_datasets/Occluded_Duke'
21
+ self .root = osp .abspath (osp .expanduser (root ))
22
+ # self.dataset_dir = self.root
23
+ data_dir = osp .join (self .root , dataset_dir )
24
+ if osp .isdir (data_dir ):
25
+ self .data_dir = data_dir
26
+ else :
27
+ warnings .warn ('The current data structure is deprecated.' )
28
+ self .train_dir = osp .join (self .data_dir , 'bounding_box_train' )
29
+ self .query_dir = osp .join (self .data_dir , 'query' )
30
+ self .gallery_dir = osp .join (self .data_dir , 'bounding_box_test' )
31
+
32
+ train = self .process_dir (self .train_dir , relabel = True )
33
+ query = self .process_dir (self .query_dir , relabel = False )
34
+ gallery = self .process_dir (self .gallery_dir , relabel = False )
35
+ super (Occluded_duke , self ).__init__ (train , query , gallery , ** kwargs )
36
+ self .load_pose = isinstance (self .transform , tuple )
37
+ if self .load_pose :
38
+ self .train_pose_dir = osp .join (self .data_dir , 'bounding_box_train_pose' )
39
+ self .gallery_pose_dir = osp .join (self .data_dir , 'bounding_box_test_pose' )
40
+ self .query_pose_dir = osp .join (self .data_dir , 'query_pose' )
41
+ if self .mode == 'train' :
42
+ self .pose_dir = self .train_pose_dir
43
+ elif self .mode == 'query' :
44
+ self .pose_dir = self .query_pose_dir
45
+ elif self .mode == 'gallery' :
46
+ self .pose_dir = self .gallery_pose_dir
47
+ else :
48
+ raise ValueError ('Invalid mode. Got {}, but expected to be '
49
+ 'one of [train | query | gallery]' .format (self .mode ))
50
+
51
+ def process_dir (self , dir_path , relabel = False ):
52
+ img_paths = glob .glob (osp .join (dir_path , '*.jpg' ))
53
+ pattern = re .compile (r'([-\d]+)_c(\d)' )
54
+
55
+ pid_container = set ()
56
+ for img_path in img_paths :
57
+ pid , _ = map (int , pattern .search (img_path ).groups ())
58
+ pid_container .add (pid )
59
+ pid2label = {pid :label for label , pid in enumerate (pid_container )}
60
+
61
+ data = []
62
+ for img_path in img_paths :
63
+ pid , camid = map (int , pattern .search (img_path ).groups ())
64
+ assert 1 <= camid <= 8
65
+ camid -= 1 # index starts from 0
66
+ if relabel : pid = pid2label [pid ]
67
+ data .append ((img_path , pid , camid ))
68
+
69
+ return data
70
+
71
+ def __getitem__ (self , index ):
72
+ img_path , pid , camid = self .data [index ]
73
+ img = read_image (img_path )
74
+
75
+ if self .load_pose :
76
+ img_name = '.' .join (img_path .split ('/' )[- 1 ].split ('.' )[:- 1 ])
77
+ pose_pic_name = img_name + '_pose_heatmaps.png'
78
+ pose_pic_path = os .path .join (self .pose_dir , pose_pic_name )
79
+ pose = cv2 .imread (pose_pic_path , cv2 .IMREAD_GRAYSCALE )
80
+ pose = pose .reshape ((pose .shape [0 ], 56 , - 1 )).transpose ((0 ,2 ,1 )).astype ('float32' )
81
+ pose [:,:,18 :] = np .abs (pose [:,:,18 :]- 128 )
82
+ img , pose = self .transform [1 ](img , pose )
83
+ img = self .transform [0 ](img )
84
+ return img , pid , camid , img_path , pose
85
+ else :
86
+ if self .transform is not None :
87
+ img = self .transform (img )
88
+ return img , pid , camid , img_path
0 commit comments