-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Release the video code of PerSAM on DAVIS
- Loading branch information
1 parent
d4b42a8
commit b6f29e9
Showing
11 changed files
with
1,441 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
|
||
__version__ = '0.1.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
import os | ||
from glob import glob | ||
from collections import defaultdict | ||
import numpy as np | ||
from PIL import Image | ||
from torch.utils.data.dataset import Dataset | ||
from torch.utils.data import DataLoader | ||
from os import path | ||
|
||
def all_to_onehot(masks, labels):
    """Convert an integer label map into a stack of per-label binary masks.

    Args:
        masks: uint8 array of shape (T, H, W) (video) or (H, W) (single frame)
            holding integer object labels.
        labels: sequence of label ids; one output channel is produced per id.

    Returns:
        uint8 array of shape (len(labels), *masks.shape) where channel k is
        the 0/1 mask of pixels equal to labels[k].
    """
    if len(masks.shape) == 3:
        out_shape = (len(labels), masks.shape[0], masks.shape[1], masks.shape[2])
    else:
        out_shape = (len(labels), masks.shape[0], masks.shape[1])
    onehot = np.zeros(out_shape, dtype=np.uint8)

    for channel, label in enumerate(labels):
        onehot[channel] = (masks == label).astype(np.uint8)

    return onehot
|
||
class DAVISTestDataset(Dataset):
    """DAVIS inference dataset: yields one whole video per item (all RGB
    frames plus one-hot ground-truth masks).

    Expected layout under `root`:
        JPEGImages/<res>/<video>/00000.jpg ...
        Annotations/<res>/<video>/00000.png ...
        ImageSets/<imset>
    """

    def __init__(self, root, imset='2017/val.txt', resolution=480, single_object=False, target_name=None):
        """
        :param root: Path to the DAVIS folder.
        :param imset: Sequence list file, relative to ImageSets/.
        :param resolution: 480 selects the '480p' folders; any other value
            selects 'Full-Resolution'.
        :param single_object: If True, collapse all objects into one binary mask.
        :param target_name: If given, keep only the video with this exact name.
        """
        self.root = root
        res_tag = '480p' if resolution == 480 else 'Full-Resolution'
        self.mask_dir = path.join(root, 'Annotations', res_tag)
        self.mask480_dir = path.join(root, 'Annotations', '480p')
        self.image_dir = path.join(root, 'JPEGImages', res_tag)
        self.resolution = resolution
        _imset_f = path.join(root, 'ImageSets', imset)

        self.videos = []
        self.num_frames = {}
        self.num_objects = {}
        self.shape = {}
        self.size_480p = {}
        with open(_imset_f, "r") as lines:
            for line in lines:
                _video = line.rstrip('\n')
                if target_name is not None and target_name != _video:
                    continue
                self.videos.append(_video)
                self.num_frames[_video] = len(os.listdir(path.join(self.image_dir, _video)))
                # The first-frame annotation determines object count and spatial size.
                _mask = np.array(Image.open(path.join(self.mask_dir, _video, '00000.png')).convert("P"))
                self.num_objects[_video] = np.max(_mask)
                self.shape[_video] = np.shape(_mask)
                _mask480 = np.array(Image.open(path.join(self.mask480_dir, _video, '00000.png')).convert("P"))
                self.size_480p[_video] = np.shape(_mask480)
        self.single_object = single_object

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, index):
        """Load a full video.

        :return: dict with 'rgb' (T, H, W, 3 uint8), 'gt' (num_objects, T, H, W
            uint8 one-hot masks) and 'info' metadata (name, frame names,
            num_frames, size_480p, labels).
        """
        video = self.videos[index]
        info = {
            'name': video,
            'frames': [],
            'num_frames': self.num_frames[video],
            'size_480p': self.size_480p[video],
        }

        images = []
        masks = []
        for f in range(self.num_frames[video]):
            frame_name = '{:05d}'.format(f)
            img_file = path.join(self.image_dir, video, frame_name + '.jpg')
            img = Image.open(img_file).convert('RGB')
            images.append(np.array(img, dtype='uint8'))
            info['frames'].append(frame_name + '.jpg')

            mask_file = path.join(self.mask_dir, video, frame_name + '.png')
            if path.exists(mask_file):
                masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8))  # (H, W)
            else:
                # BUG FIX: the original used np.zeros_like(masks[0]), which
                # raises IndexError when the very first frame has no annotation.
                # Use the size recorded from the first-frame annotation instead.
                masks.append(np.zeros(self.shape[video], dtype=np.uint8))

        images = np.stack(images, 0)
        masks = np.stack(masks, 0)

        if self.single_object:
            # Merge every object into a single foreground label.
            labels = [1]
            masks = (masks > 0.5).astype(np.uint8)
            masks = all_to_onehot(masks, labels)
        else:
            # Object ids present in the first frame (0 is background).
            labels = np.unique(masks[0])
            labels = labels[labels != 0]
            masks = all_to_onehot(masks, labels)

        info['labels'] = labels

        return {
            'rgb': images,
            'gt': masks,
            'info': info,
        }
|
||
|
||
class DAVIS(object):
    SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge']
    TASKS = ['semi-supervised', 'unsupervised']
    DATASET_WEB = 'https://davischallenge.org/davis2017/code.html'
    VOID_LABEL = 255  # pixels with this value are "void" (excluded from evaluation)

    def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False):
        """
        Class to read the DAVIS dataset
        :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
        :param task: Task to load the annotations, choose between semi-supervised or unsupervised.
        :param subset: Set to load the annotations
        :param sequences: Sequences to consider, 'all' to use all the sequences in a set.
        :param resolution: Specify the resolution to use the dataset, choose between '480p' and 'Full-Resolution'
        :param codalab: If True, tolerate sequences that have no images on disk
            (server-side evaluation where only result masks exist).
        """
        if subset not in self.SUBSET_OPTIONS:
            raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}')
        if task not in self.TASKS:
            raise ValueError(f'The only tasks that are supported are {self.TASKS}')

        self.task = task
        self.subset = subset
        self.root = root
        self.img_path = os.path.join(self.root, 'JPEGImages', resolution)
        annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised'
        self.mask_path = os.path.join(self.root, annotations_folder, resolution)
        # Unsupervised test sets were released for the 2019 challenge.
        year = '2019' if task == 'unsupervised' and subset in ('test-dev', 'test-challenge') else '2017'
        self.imagesets_path = os.path.join(self.root, 'ImageSets', year)

        self._check_directories()

        if sequences == 'all':
            with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f:
                sequences_names = [x.strip() for x in f.readlines()]
        else:
            sequences_names = sequences if isinstance(sequences, list) else [sequences]
        self.sequences = defaultdict(dict)

        for seq in sequences_names:
            images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist()
            if len(images) == 0 and not codalab:
                raise FileNotFoundError(f'Images for sequence {seq} not found.')
            self.sequences[seq]['images'] = images
            masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist()
            # BUG FIX: frames without an annotation were padded with -1, but
            # get_frames() tests `msk is None`, which could never be true, and
            # Image.open(-1) would crash. Pad with None so the check works.
            masks.extend([None] * (len(images) - len(masks)))
            self.sequences[seq]['masks'] = masks

    def _check_directories(self):
        """Raise FileNotFoundError if the expected DAVIS folders are missing."""
        if not os.path.exists(self.root):
            raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}')
        if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')):
            raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset '
                                    f'for the {self.task} task from {self.DATASET_WEB}')
        if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path):
            raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}')

    def get_frames(self, sequence):
        """Yield (image, mask) numpy pairs for every frame of `sequence`.
        `mask` is None for frames that have no annotation file."""
        for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']):
            image = np.array(Image.open(img))
            mask = None if msk is None else np.array(Image.open(msk))
            yield image, mask

    def _get_all_elements(self, sequence, obj_type):
        """Load every file of `obj_type` ('images' or 'masks') for a sequence.

        :return: (array of shape (T, H, W[, C]), list of frame-name ids)
        """
        obj = np.array(Image.open(self.sequences[sequence][obj_type][0]))
        all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape))
        obj_id = []
        for i, obj in enumerate(self.sequences[sequence][obj_type]):
            all_objs[i, ...] = np.array(Image.open(obj))
            # Frame id = file name without directory and extension.
            obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1]))
        return all_objs, obj_id

    def get_all_images(self, sequence):
        return self._get_all_elements(sequence, 'images')

    def get_all_masks(self, sequence, separate_objects_masks=False):
        """Return (masks, void masks, frame ids) for a sequence.

        With separate_objects_masks=True, masks has an extra leading object
        axis with one boolean channel per object id.
        """
        masks, masks_id = self._get_all_elements(sequence, 'masks')
        masks_void = np.zeros_like(masks)

        # Separate void and object masks
        for i in range(masks.shape[0]):
            masks_void[i, ...] = masks[i, ...] == 255
            masks[i, masks[i, ...] == 255] = 0

        if separate_objects_masks:
            num_objects = int(np.max(masks[0, ...]))
            # Broadcast object ids 1..num_objects against the label maps.
            tmp = np.ones((num_objects, *masks.shape))
            tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None]
            masks = (tmp == masks[None, ...])
            masks = masks > 0
        return masks, masks_void, masks_id

    def get_sequences(self):
        for seq in self.sequences:
            yield seq
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import sys, os | ||
from tqdm import tqdm | ||
import warnings | ||
warnings.filterwarnings("ignore", category=RuntimeWarning) | ||
|
||
import numpy as np | ||
from davis2017.davis import DAVIS | ||
from davis2017.metrics import db_eval_boundary, db_eval_iou | ||
from davis2017 import utils | ||
from davis2017.results import Results | ||
from scipy.optimize import linear_sum_assignment | ||
|
||
|
||
class DAVISEvaluation(object):
    def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False):
        """
        Class to evaluate DAVIS sequences from a certain set and for a certain task
        :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
        :param task: Task to compute the evaluation, chose between semi-supervised or unsupervised.
        :param gt_set: Set to compute the evaluation
        :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set.
        """
        self.davis_root = davis_root
        self.task = task
        self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab)

    @staticmethod
    def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric):
        """Per-object J/F metrics with GT-to-result objects matched by index.

        Result masks are zero-padded when fewer objects were predicted than
        annotated; extra predicted objects are an error.
        :return: (j_metrics, f_metrics), each of shape (num_objects, num_frames).
        """
        if all_res_masks.shape[0] > all_gt_masks.shape[0]:
            sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!")
            sys.exit()
        elif all_res_masks.shape[0] < all_gt_masks.shape[0]:
            zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))
            all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)
        j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2])
        for ii in range(all_gt_masks.shape[0]):
            if 'J' in metric:
                j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
            if 'F' in metric:
                f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
        return j_metrics_res, f_metrics_res

    @staticmethod
    def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20):
        """J/F metrics with proposals assigned to GT objects by the Hungarian
        algorithm (maximizing the mean selected metric).

        :return: (j_metrics, f_metrics) for the optimal proposal-to-object
            assignment, each of shape (num_objects, num_frames).
        """
        if all_res_masks.shape[0] > max_n_proposals:
            sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!")
            sys.exit()
        elif all_res_masks.shape[0] < all_gt_masks.shape[0]:
            zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))
            all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)
        j_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1]))
        f_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1]))
        for ii in range(all_gt_masks.shape[0]):
            for jj in range(all_res_masks.shape[0]):
                if 'J' in metric:
                    j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks)
                if 'F' in metric:
                    f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks)
        if 'J' in metric and 'F' in metric:
            all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2
        else:
            all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2)
        # Hungarian assignment: maximize total metric (minimize its negative).
        row_ind, col_ind = linear_sum_assignment(-all_metrics)
        return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :]

    def evaluate(self, res_path, metric=('J', 'F'), debug=False):
        """Evaluate every sequence found in `res_path`.

        :param res_path: Folder with one sub-folder of result PNGs per sequence.
        :param metric: 'J', 'F', or both.
        :param debug: If True, print each sequence name as it finishes.
        :return: dict {metric: {"M": means, "R": recalls, "D": decays,
            "M_per_object": {seq_objid: mean}}}.
        """
        metric = metric if isinstance(metric, (tuple, list)) else [metric]
        if 'T' in metric:
            raise ValueError('Temporal metric not supported!')
        if 'J' not in metric and 'F' not in metric:
            raise ValueError('Metric possible values are J for IoU or F for Boundary')

        # Containers
        metrics_res = {}
        if 'J' in metric:
            metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
        if 'F' in metric:
            metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}}

        # Sweep all sequences
        results = Results(root_dir=res_path)
        seq_list = os.listdir(res_path)
        for seq in tqdm(seq_list):
            print("Calculating Class", seq)
            all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True)
            if self.task == 'semi-supervised':
                # First and last frames are excluded for semi-supervised eval.
                all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1]
            all_res_masks = results.read_masks(seq, all_masks_id)
            if self.task == 'unsupervised':
                j_metrics_res, f_metrics_res = self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric)
            elif self.task == 'semi-supervised':
                j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric)
            for ii in range(all_gt_masks.shape[0]):
                seq_name = f'{seq}_{ii+1}'
                if 'J' in metric:
                    [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii])
                    metrics_res['J']["M"].append(JM)
                    metrics_res['J']["R"].append(JR)
                    metrics_res['J']["D"].append(JD)
                    metrics_res['J']["M_per_object"][seq_name] = JM
                if 'F' in metric:
                    [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii])
                    metrics_res['F']["M"].append(FM)
                    metrics_res['F']["R"].append(FR)
                    metrics_res['F']["D"].append(FD)
                    metrics_res['F']["M_per_object"][seq_name] = FM

            # Show progress
            if debug:
                sys.stdout.write(seq + '\n')
                sys.stdout.flush()
        return metrics_res
Oops, something went wrong.