Baseline Update
ThuanNaN committed Mar 6, 2023
1 parent da365d0 commit c3507c2
Showing 25 changed files with 1,656 additions and 95 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,7 +5,8 @@ __pycache__/

# C extensions
*.so

data/
checkpoint/
# Distribution / packaging
.Python
build/
1 change: 1 addition & 0 deletions access.txt
@@ -0,0 +1 @@
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
36 changes: 22 additions & 14 deletions configs/baseline_config.json
@@ -1,30 +1,31 @@
{
"general_config": {
"data_dir": "/kaggle/input/",
"data_dir": "./data/AIC23_Track2_NL_Retrieval/data",
"max_len": 64,
"train_batch_size": 4,
"valid_batch_size": 4,
"train_batch_size": 16,
"valid_batch_size": 16,
"n_workers": 8,
"kfolds": 5,
"gradient_checkpointing": true,
"epochs": 1,
"n_warmup_steps": 50,
"gradient_accumulation_steps": 4,
"unscale": false,
"epochs": 50,
"n_warmup_steps": 0,
"gradient_accumulation_steps": 1,
"unscale": true,
"evaluate_n_times_per_epoch": 1,
"max_grad_norm": 1000,
"train_print_frequency": 100,
"valid_print_frequency": 100,
"loss": "DCL"
"train_print_frequency": 20,
"valid_print_frequency": 20,
"loss": "InfoNCE",
"load_checkpoint": null
},
"optimizer": {
"weight_decay": 5e-5,
"learning_rate": 3e-6,
"learning_rate": 0.00003,
"eps": 1e-8,
"betas": [0.9, 0.999]
},
"scheduler": {
"scheduler_type": "cosine_schedule_with_warmup",
"scheduler_type": "linear_warmup_cosine_annealing_lr",
"batch_scheduler": true,
"constant_schedule_with_warmup": { "n_warmup_steps": 0 },
"linear_schedule_with_warmup": { "n_warmup_steps": 0 },
@@ -33,6 +34,10 @@
"n_warmup_steps": 0,
"power": 1.0,
"min_lr": 0.0
},
"linear_warmup_cosine_annealing_lr": {
"warmup_epochs": 0,
"max_epochs": 20
}
},
"arch": {
@@ -43,15 +48,18 @@
"num_frames": 4,
"pretrained": true,
"time_init": "zeros",
"input_res": 224
"input_res": 224,
"color_classes": 8,
"type_classes": 6,
"motion_classes": 4
},
"text_params": {
"model": "distilbert-base-uncased",
"pretrained": true,
"input": "text"
},
"projection": "minimal",
"load_checkpoint": "/kaggle/input/frozenintime-ckpt/cc-webvid2m-4f_stformer_b_16_224.pth.tar"
"load_checkpoint": "./checkpoint/archive/cc-webvid2m-4f_stformer_b_16_224.pth.tar"
},
"text_head_setting": {
"type": "ContextualizedWeightedHead",
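The updated config replaces cosine_schedule_with_warmup with linear_warmup_cosine_annealing_lr (warmup_epochs 0, max_epochs 20). The scheduler implementation itself is not part of this diff; the sketch below is only the standard linear-warmup plus cosine-annealing formula that these parameters usually drive, with illustrative names.

import math

# Hedged sketch (not from this commit): per-epoch LR following a linear warmup
# and then cosine annealing, as the new scheduler name and parameters suggest.
def linear_warmup_cosine_annealing(epoch, base_lr, warmup_epochs=0, max_epochs=20, min_lr=0.0):
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs  # linear ramp-up
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

# With the values above (learning_rate 3e-5, warmup_epochs 0, max_epochs 20),
# the LR starts at 3e-5 and decays to 0 by epoch 20.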
122 changes: 103 additions & 19 deletions dataloader/datasets.py
@@ -6,25 +6,57 @@
import torchvision.transforms as transforms

class Track2CustomDataset(Dataset):
def __init__(self, video_params, data_tracks, tokenizer, max_len, transforms, config, train=True):
def __init__(self, video_params, data_tracks, tokenizer, max_len, transforms, config, mode="train"):

self.samples = data_tracks
self.transforms = transforms
self.tokenizer = tokenizer
self.max_len = max_len
self.video_params = video_params
self.train = train
self.mode = mode
self.config = config

def __len__(self):
return len(self.samples)

def __getitem__(self, index):

sample = self.samples.iloc[index]
frames_path, boxes, nl_descriptions, frames_dir = sample['frames'], sample['boxes'], sample['nl'], sample['frames_dir']

if self.mode == "train":

final, motion, motion_line = self.image_features(sample)
text_inputs = self.lang_features(sample)

sample = {
'text': text_inputs,
'video': final,
'motion': motion,
'motion_line': motion_line,
'color_label': sample['colors'],
'type_label': sample['type'],
'motion_label': sample['motion']

}

return sample

if self.mode == "infer_text":

text_inputs = self.lang_features(sample)
return {'text': text_inputs}

if self.mode == "infer_video":
final, motion, motion_line = self.image_features(sample)
return {
'video': final,
'motion': motion,
'motion_line': motion_line
}

def image_features(self, sample):
frames_path, boxes = sample['frames'], sample['boxes']

veh_imgs, motion_line, motion = get_motion_img(os.path.join(self.config['general_config']['data_dir'], frames_dir), frames_path, boxes)
veh_imgs, motion_line, motion = get_motion_img(self.config['general_config']['data_dir'], frames_path, boxes, self.config['arch']['base_settings']['video_params']['num_frames'])

if self.transforms:
veh_imgs = [self.transforms(img.astype(np.float32)) for img in veh_imgs]
@@ -35,7 +67,11 @@ def __getitem__(self, index):

final = torch.zeros([self.video_params['num_frames'], 3, self.video_params['input_res'], self.video_params['input_res']])
final[: veh_imgs.shape[0]] = veh_imgs


return final, motion, motion_line

def lang_features(self, sample):
nl_descriptions = sample['nl']
text_inputs = []
for idx, text in enumerate(nl_descriptions):
# print("text: ", text, ", idx: ", idx)
@@ -49,25 +85,36 @@
text_inputs.append({
'input_ids': torch.LongTensor(tokenized_inp['input_ids']),
'attention_mask': torch.LongTensor(tokenized_inp['attention_mask'])
})
})

sample = {
'text': text_inputs,
'video': final,
'motion': motion,
'motion_line': motion_line
}

return sample
return text_inputs

def videotext_collate_fn(batch_data):
frames = torch.stack([item['video'] for item in batch_data])
motion = torch.stack([item['motion'] for item in batch_data])
motion_line = torch.stack([item['motion_line'] for item in batch_data])
input_ids = torch.stack([cap['input_ids'] for item in batch_data for cap in item['text']])
attention_mask = torch.stack([cap['attention_mask'] for item in batch_data for cap in item['text']])
color_label = torch.LongTensor([item['color_label'] for item in batch_data])
type_label = torch.LongTensor([item['type_label'] for item in batch_data])
motion_label = torch.LongTensor([item['motion_label'] for item in batch_data])

return {'video': frames, 'text': {'input_ids': input_ids, 'attention_mask': attention_mask}, 'motion': motion, 'motion_line': motion_line,
'color_label': color_label, 'type_label': type_label, 'motion_label': motion_label}

def text_collate_fn(batch_data):
input_ids = torch.stack([cap['input_ids'] for item in batch_data for cap in item['text']])
attention_mask = torch.stack([cap['attention_mask'] for item in batch_data for cap in item['text']])

return {'video': frames, 'text': {'input_ids': input_ids, 'attention_mask': attention_mask}, 'motion': motion, 'motion_line': motion_line}
return {'text': {'input_ids': input_ids, 'attention_mask': attention_mask}}

def video_collate_fn(batch_data):
frames = torch.stack([item['video'] for item in batch_data])
motion = torch.stack([item['motion'] for item in batch_data])
motion_line = torch.stack([item['motion_line'] for item in batch_data])
return {'video': frames, 'motion': motion, 'motion_line': motion_line}
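
The three collate functions mirror the dataset modes: videotext_collate_fn flattens every caption of every sample into one text batch (so the text batch dimension is batch size times captions per sample), while video_collate_fn and text_collate_fn serve the two inference loaders. A quick shape check with hypothetical toy tensors (not real dataset output):

import torch

toy_batch = [
    {
        'video': torch.zeros(4, 3, 224, 224),
        'motion': torch.zeros(3, 224, 224),
        'motion_line': torch.zeros(3, 224, 224),
        'text': [{'input_ids': torch.zeros(64, dtype=torch.long),
                  'attention_mask': torch.ones(64, dtype=torch.long)} for _ in range(3)],
        'color_label': 2, 'type_label': 1, 'motion_label': 0,
    }
    for _ in range(2)
]

out = videotext_collate_fn(toy_batch)      # videotext_collate_fn as defined above
print(out['video'].shape)                  # torch.Size([2, 4, 3, 224, 224])
print(out['text']['input_ids'].shape)      # torch.Size([6, 64]) -> 2 samples x 3 captions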



def get_transforms(img_size, train, size=1):
if train:
@@ -103,14 +150,13 @@ def get_train_dataloader(config, df):
)

return dataloader

def get_valid_dataloader(config, df):
dataset = Track2CustomDataset(data_tracks=df,
video_params=config.arch.base_settings.video_params,
tokenizer=config.general_config.tokenizer,
max_len=int(config.general_config.max_len),
transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=False),
train=False,
config=config)
dataloader = torch.utils.data.DataLoader(
dataset,
@@ -122,4 +168,42 @@ def get_valid_dataloader(config, df):
drop_last=False
)

return dataloader
return dataloader

def get_infer_dataloader(config, df_video, df_text):
text_dataset = Track2CustomDataset(data_tracks=df_text,
video_params=config.arch.base_settings.video_params,
tokenizer=config.general_config.tokenizer,
max_len=int(config.general_config.max_len),
transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=False),
config=config,
mode="infer_text")

video_dataset = Track2CustomDataset(data_tracks=df_video,
video_params=config.arch.base_settings.video_params,
tokenizer=config.general_config.tokenizer,
max_len=int(config.general_config.max_len),
transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=False),
config=config,
mode="infer_video")

text_dataloader = torch.utils.data.DataLoader(
text_dataset,
batch_size=config.general_config.valid_batch_size,
num_workers=config.general_config.n_workers,
collate_fn=text_collate_fn,
shuffle=False,
pin_memory=True,
drop_last=False
)

video_dataloader = torch.utils.data.DataLoader(
video_dataset,
batch_size=config.general_config.valid_batch_size,
num_workers=config.general_config.n_workers,
collate_fn=video_collate_fn,
shuffle=False,
pin_memory=True,
drop_last=False
)
return video_dataloader, text_dataloader
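
For reference, a usage sketch of the new get_infer_dataloader (config, df_video and df_text are assumed to already exist; they come from the project's config loader and track parsing). It returns one loader per modality, and each batch carries the keys produced by the corresponding collate function:

video_loader, text_loader = get_infer_dataloader(config, df_video, df_text)

for batch in video_loader:
    # keys from video_collate_fn
    print(batch['video'].shape, batch['motion'].shape, batch['motion_line'].shape)
    break

for batch in text_loader:
    # keys from text_collate_fn
    print(batch['text']['input_ids'].shape, batch['text']['attention_mask'].shape)
    break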
43 changes: 24 additions & 19 deletions dataloader/utils.py
@@ -43,32 +43,37 @@ def bb_intersection_over_union(boxA, boxB):
return iou

def get_motion_img(data_dir, img_path_list, boxes_list, num_frames):
w, h, c = cv2.imread(os.path.join(data_dir, img_path_list[0][2:])).shape
motion_img = np.zeros((w, h, c), dtype=np.int16)
first = cv2.imread(os.path.join(data_dir, img_path_list[0][2:]))
w, h, c = first.shape
motion_img = first
line_motion_img = np.zeros((w, h, c), dtype=np.int16)

center_points = []

prev_boxes = None
box_indices = []

for idx, boxes in enumerate(boxes_list):
x, y, w, h = boxes


prev_box = []

for idx, img_path in enumerate(img_path_list):
if prev_boxes is None:
prev_boxes = [x, y, x+w, y+h]
box_indices.append(idx)
else:
curr_boxes = [x, y, x+w, y+h]

if bb_intersection_over_union(prev_boxes, curr_boxes) < 0.05:
prev_boxes = curr_boxes
box_indices.append(idx)

# print("BOXLEN: ", len(box_indices))
for idx in box_indices:
img_path = img_path_list[idx]
img = cv2.imread(os.path.join(data_dir, img_path[2:]))

x, y, w, h = boxes_list[idx]
context_img = img[y:y+h, x:x+w, :]

if len(prev_box) == 0:
prev_box = [x, y, x+w, y+h]
else:
curr_box = [x, y, x+w, y+h]
if bb_intersection_over_union(prev_box, curr_box) > 0.5:
continue
else:
prev_box = curr_box

x, y, w, h = boxes_list[idx]

motion_img[y:y+h, x:x+w, :] = context_img
center_points.append((int(x+ w/2),int(y + h/2)))

@@ -77,10 +82,10 @@ def get_motion_img(data_dir, img_path_list, boxes_list, num_frames):
for point1, point2 in zip(center_points, center_points[1:]):
cv2.line(line_motion_img, point1, point2, [255, 255, 255], 80)

indexes = sample_frames(num_frames, len(img_path_list), sample='uniform')
frame_indexes = sample_frames(num_frames, len(img_path_list), sample='uniform')

context_images = []
for idx in indexes:
for idx in frame_indexes:
img = cv2.imread(os.path.join(data_dir, img_path_list[idx][2:]))
x, y, w, h = boxes_list[idx]
context_img = img[y:y+h, x:x+w, :]
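get_motion_img depends on a sample_frames helper whose definition is not shown in this diff. Assuming 'uniform' means evenly spaced indices across the clip, a minimal illustrative version could look like this (the project's actual helper may differ):

import numpy as np

def sample_frames(num_frames, total_frames, sample='uniform'):
    # Evenly spaced frame indices; return every frame when the clip is short.
    if total_frames <= num_frames:
        return list(range(total_frames))
    if sample == 'uniform':
        return np.linspace(0, total_frames - 1, num_frames).round().astype(int).tolist()
    raise ValueError(f"unknown sampling strategy: {sample}")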
2 changes: 1 addition & 1 deletion datasets.py
@@ -7,7 +7,7 @@
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchvision
from utils import get_logger
from utils_ import get_logger

def default_loader(path):
return Image.open(path).convert('RGB')
