Skip to content

Commit 788b37f

Browse files
authored
[Feature] Support NYU depth estimation dataset (open-mmlab#3269)
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Please describe the motivation of this PR and the goal you want to achieve through this PR. ## Modification Please briefly describe what modification is made in this PR. 1. add `NYUDataset`class 2. add script to process NYU dataset 3. add transforms for loading depth map 4. add docs & unittest ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 5. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 6. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 7. The documentation has been modified accordingly, like docstring or example tutorials.
1 parent 9277418 commit 788b37f

File tree

13 files changed

+367
-8
lines changed

13 files changed

+367
-8
lines changed

docs/en/user_guides/2_dataset_prepare.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ mmsegmentation
198198
| │   │   │ └── rles
199199
| │ │ │ │ ├──sem_seg_train.json
200200
| │ │ │ │ └──sem_seg_val.json
201+
│ ├── nyu
202+
│ │ ├── images
203+
│ │ │ ├── train
204+
│ │ │ ├── test
205+
│ │ ├── annotations
206+
│ │ │ ├── train
207+
│ │ │ ├── test
201208
```
202209

203210
## Download dataset via MIM
@@ -735,3 +742,13 @@ mmsegmentation
735742
| │ │ │ │ ├──sem_seg_train.json
736743
| │ │ │ │ └──sem_seg_val.json
737744
```
745+
746+
## NYU
747+
748+
- To access the NYU dataset, you can download it from [this link](https://drive.google.com/file/d/1wC-io-14RCIL4XTUrQLk6lBqU2AexLVp/view?usp=share_link)
749+
750+
- Once the download is complete, you can utilize the [tools/dataset_converters/nyu.py](/tools/dataset_converters/nyu.py) script to extract and organize the data into the required format. Run the following command in your terminal:
751+
752+
```bash
753+
python tools/dataset_converters/nyu.py nyu.zip
754+
```

docs/zh_cn/user_guides/2_dataset_prepare.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ mmsegmentation
198198
| │   │   │ └── rles
199199
| │ │ │ │ ├──sem_seg_train.json
200200
| │ │ │ │ └──sem_seg_val.json
201+
│ ├── nyu
202+
│ │ ├── images
203+
│ │ │ ├── train
204+
│ │ │ ├── test
205+
│ │ ├── annotations
206+
│ │ │ ├── train
207+
│ │ │ ├── test
201208
```
202209

203210
## 用 MIM 下载数据集
@@ -731,3 +738,13 @@ mmsegmentation
731738
| │ │ │ │ ├──sem_seg_train.json
732739
| │ │ │ │ └──sem_seg_val.json
733740
```
741+
742+
## NYU
743+
744+
- 您可以从 [这个链接](https://drive.google.com/file/d/1wC-io-14RCIL4XTUrQLk6lBqU2AexLVp/view?usp=share_link) 下载 NYU 数据集
745+
746+
- 下载完成后,您可以使用 [tools/dataset_converters/nyu.py](/tools/dataset_converters/nyu.py) 脚本来解压和组织数据到所需的格式
747+
748+
```bash
749+
python tools/dataset_converters/nyu.py nyu.zip
750+
```

mmseg/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .loveda import LoveDADataset
2020
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
2121
from .night_driving import NightDrivingDataset
22+
from .nyu import NYUDataset
2223
from .pascal_context import PascalContextDataset, PascalContextDataset59
2324
from .potsdam import PotsdamDataset
2425
from .refuge import REFUGEDataset
@@ -58,5 +59,6 @@
5859
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
5960
'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
6061
'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
61-
'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset'
62+
'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
63+
'NYUDataset'
6264
]

mmseg/datasets/nyu.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os.path as osp
3+
from typing import List
4+
5+
import mmengine.fileio as fileio
6+
7+
from mmseg.registry import DATASETS
8+
from .basesegdataset import BaseSegDataset
9+
10+
11+
@DATASETS.register_module()
12+
class NYUDataset(BaseSegDataset):
13+
"""NYU depth estimation dataset. The file structure should be.
14+
15+
.. code-block:: none
16+
17+
├── data
18+
│ ├── nyu
19+
│ │ ├── images
20+
│ │ │ ├── train
21+
│ │ │ │ ├── scene_xxx.jpg
22+
│ │ │ │ ├── ...
23+
│ │ │ ├── test
24+
│ │ ├── annotations
25+
│ │ │ ├── train
26+
│ │ │ │ ├── scene_xxx.png
27+
│ │ │ │ ├── ...
28+
│ │ │ ├── test
29+
30+
Args:
31+
ann_file (str): Annotation file path. Defaults to ''.
32+
metainfo (dict, optional): Meta information for dataset, such as
33+
specify classes to load. Defaults to None.
34+
data_root (str, optional): The root directory for ``data_prefix`` and
35+
``ann_file``. Defaults to None.
36+
data_prefix (dict, optional): Prefix for training data. Defaults to
37+
dict(img_path='images', depth_map_path='annotations').
38+
img_suffix (str): Suffix of images. Default: '.jpg'
39+
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
40+
filter_cfg (dict, optional): Config for filter data. Defaults to None.
41+
indices (int or Sequence[int], optional): Support using first few
42+
data in annotation file to facilitate training/testing on a smaller
43+
dataset. Defaults to None which means using all ``data_infos``.
44+
serialize_data (bool, optional): Whether to hold memory using
45+
serialized objects, when enabled, data loader workers can use
46+
shared RAM from master process instead of making a copy. Defaults
47+
to True.
48+
pipeline (list, optional): Processing pipeline. Defaults to [].
49+
test_mode (bool, optional): ``test_mode=True`` means in test phase.
50+
Defaults to False.
51+
lazy_init (bool, optional): Whether to load annotation during
52+
instantiation. In some cases, such as visualization, only the meta
53+
information of the dataset is needed, which is not necessary to
54+
load annotation file. ``Basedataset`` can skip load annotations to
55+
save time by set ``lazy_init=True``. Defaults to False.
56+
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
57+
None img. The maximum extra number of cycles to get a valid
58+
image. Defaults to 1000.
59+
ignore_index (int): The label index to be ignored. Default: 255
60+
reduce_zero_label (bool): Whether to mark label zero as ignored.
61+
Default to False.
62+
backend_args (dict, Optional): Arguments to instantiate a file backend.
63+
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
64+
for details. Defaults to None.
65+
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
66+
"""
67+
METAINFO = dict(
68+
classes=('printer_room', 'bathroom', 'living_room', 'study',
69+
'conference_room', 'study_room', 'kitchen', 'home_office',
70+
'bedroom', 'dinette', 'playroom', 'indoor_balcony',
71+
'laundry_room', 'basement', 'excercise_room', 'foyer',
72+
'home_storage', 'cafe', 'furniture_store', 'office_kitchen',
73+
'student_lounge', 'dining_room', 'reception_room',
74+
'computer_lab', 'classroom', 'office', 'bookstore'))
75+
76+
def __init__(self,
77+
data_prefix=dict(
78+
img_path='images', depth_map_path='annotations'),
79+
img_suffix='.jpg',
80+
depth_map_suffix='.png',
81+
**kwargs) -> None:
82+
super().__init__(
83+
data_prefix=data_prefix,
84+
img_suffix=img_suffix,
85+
seg_map_suffix=depth_map_suffix,
86+
**kwargs)
87+
88+
def _get_category_id_from_filename(self, image_fname: str) -> int:
89+
"""Retrieve the category ID from the given image filename."""
90+
image_fname = osp.basename(image_fname)
91+
position = image_fname.find(next(filter(str.isdigit, image_fname)), 0)
92+
categoty_name = image_fname[:position - 1]
93+
if categoty_name not in self._metainfo['classes']:
94+
return -1
95+
else:
96+
return self._metainfo['classes'].index(categoty_name)
97+
98+
def load_data_list(self) -> List[dict]:
99+
"""Load annotation from directory or annotation file.
100+
101+
Returns:
102+
list[dict]: All data info of dataset.
103+
"""
104+
data_list = []
105+
img_dir = self.data_prefix.get('img_path', None)
106+
ann_dir = self.data_prefix.get('depth_map_path', None)
107+
108+
_suffix_len = len(self.img_suffix)
109+
for img in fileio.list_dir_or_file(
110+
dir_path=img_dir,
111+
list_dir=False,
112+
suffix=self.img_suffix,
113+
recursive=True,
114+
backend_args=self.backend_args):
115+
data_info = dict(img_path=osp.join(img_dir, img))
116+
if ann_dir is not None:
117+
depth_map = img[:-_suffix_len] + self.seg_map_suffix
118+
data_info['depth_map_path'] = osp.join(ann_dir, depth_map)
119+
data_info['seg_fields'] = []
120+
data_info['category_id'] = self._get_category_id_from_filename(img)
121+
data_list.append(data_info)
122+
data_list = sorted(data_list, key=lambda x: x['img_path'])
123+
return data_list

mmseg/datasets/transforms/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from .formatting import PackSegInputs
33
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
44
LoadBiomedicalData, LoadBiomedicalImageFromFile,
5-
LoadImageFromNDArray, LoadMultipleRSImageFromFile,
6-
LoadSingleRSImageFromFile)
5+
LoadDepthAnnotation, LoadImageFromNDArray,
6+
LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
77
# yapf: disable
88
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
99
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
@@ -24,5 +24,5 @@
2424
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
2525
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
2626
'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
27-
'LoadMultipleRSImageFromFile'
27+
'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation'
2828
]

mmseg/datasets/transforms/formatting.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ def transform(self, results: dict) -> dict:
9292
...].astype(np.int64)))
9393
data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data)))
9494

95+
if 'gt_depth_map' in results:
96+
gt_depth_data = dict(
97+
data=to_tensor(results['gt_depth_map'][None, ...]))
98+
data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data)))
99+
95100
img_meta = {}
96101
for key in self.meta_keys:
97102
if key in results:

mmseg/datasets/transforms/loading.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,77 @@ def __repr__(self):
625625
repr_str = (f'{self.__class__.__name__}('
626626
f'to_float32={self.to_float32})')
627627
return repr_str
628+
629+
630+
@TRANSFORMS.register_module()
631+
class LoadDepthAnnotation(BaseTransform):
632+
"""Load ``depth_map`` annotation provided by depth estimation dataset.
633+
634+
The annotation format is as the following:
635+
636+
.. code-block:: python
637+
638+
{
639+
'gt_depth_map': np.ndarray [Y, X]
640+
}
641+
642+
Required Keys:
643+
644+
- seg_depth_path
645+
646+
Added Keys:
647+
648+
- gt_depth_map (np.ndarray): Depth map with shape (Y, X) by
649+
default, and data type is float32 if set to_float32 = True.
650+
651+
Args:
652+
decode_backend (str): The data decoding backend type. Options are
653+
'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'.
654+
to_float32 (bool): Whether to convert the loaded depth map to a float32
655+
numpy array. If set to False, the loaded image is an uint16 array.
656+
Defaults to True.
657+
depth_rescale_factor (float): Factor to rescale the depth value to
658+
limit the range. Defaults to 1.0.
659+
backend_args (dict, Optional): Arguments to instantiate a file backend.
660+
See :class:`mmengine.fileio` for details.
661+
Defaults to None.
662+
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
663+
"""
664+
665+
def __init__(self,
666+
decode_backend: str = 'cv2',
667+
to_float32: bool = True,
668+
depth_rescale_factor: float = 1.0,
669+
backend_args: Optional[dict] = None) -> None:
670+
super().__init__()
671+
self.decode_backend = decode_backend
672+
self.to_float32 = to_float32
673+
self.depth_rescale_factor = depth_rescale_factor
674+
self.backend_args = backend_args.copy() if backend_args else None
675+
676+
def transform(self, results: Dict) -> Dict:
677+
"""Functions to load depth map.
678+
679+
Args:
680+
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
681+
682+
Returns:
683+
dict: The dict contains loaded depth map.
684+
"""
685+
data_bytes = fileio.get(results['depth_map_path'], self.backend_args)
686+
gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend)
687+
688+
if self.to_float32:
689+
gt_depth_map = gt_depth_map.astype(np.float32)
690+
691+
gt_depth_map *= self.depth_rescale_factor
692+
results['gt_depth_map'] = gt_depth_map
693+
results['seg_fields'].append('gt_depth_map')
694+
return results
695+
696+
def __repr__(self):
697+
repr_str = (f'{self.__class__.__name__}('
698+
f"decode_backend='{self.decode_backend}', "
699+
f'to_float32={self.to_float32}, '
700+
f'backend_args={self.backend_args})')
701+
return repr_str

mmseg/utils/io.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io
44
import pickle
55

6+
import cv2
67
import numpy as np
78

89

@@ -12,7 +13,7 @@ def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
1213
Args:
1314
content (bytes): The data bytes got from files or other streams.
1415
backend (str): The data decoding backend type. Options are 'numpy',
15-
'nifti' and 'pickle'. Defaults to 'numpy'.
16+
'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'.
1617
1718
Returns:
1819
numpy.ndarray: Loaded data array.
@@ -33,6 +34,9 @@ def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
3334
data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
3435
elif backend == 'numpy':
3536
data = np.load(f)
37+
elif backend == 'cv2':
38+
data = np.frombuffer(f.read(), dtype=np.uint16)
39+
data = cv2.imdecode(data, 2)
3640
else:
3741
raise ValueError
3842
return data
Loading
Loading

0 commit comments

Comments
 (0)