
Commit 676ea1e

Reproducible code for MGTN

22 files changed: +1569 -0 lines changed

README.md (+62)
# Modular Graph Transformer Networks (MGTN)
This project implements multi-label learning based on Modular Graph Transformer Networks (MGTN).

### Requirements
Please install the following packages:
- numpy
- pytorch (1.*)
- torchnet
- torchvision
- tqdm
- networkx

### Download best checkpoints
checkpoint/coco/mgtn_final_86.9762.pth.tar ([Dropbox](https://www.dropbox.com/s/fr2286gwxsg80kq/mgtn_final_86.9762.pth.tar?dl=0))
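
A minimal sketch of inspecting the downloaded checkpoint, assuming it follows the common `{'state_dict': ...}` convention used by wildcat/ML-GCN-style training code (the exact keys are not confirmed by this commit):

```python
import torch

# Hypothetical inspection of the released checkpoint; key names such as
# 'state_dict' are assumptions, not taken from this repository.
ckpt = torch.load('checkpoint/coco/mgtn_final_86.9762.pth.tar', map_location='cpu')
print(ckpt.keys())
state_dict = ckpt.get('state_dict', ckpt)  # fall back to a raw state dict
# model.load_state_dict(state_dict)  # `model` would be the MGTN model built by mgtn.py
```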

### Performance

| Method | mAP | CP | CR | CF1 | OP | OR | OF1 |
| --------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| CNN-RNN | 61.2 | - | - | - | - | - | - |
| SRN | 77.1 | 81.6 | 65.4 | 71.2 | 82.7 | 69.9 | 75.8 |
| Baseline (ResNet101) | 77.3 | 80.2 | 66.7 | 72.8 | 83.9 | 70.8 | 76.8 |
| Multi-Evidence | - | 80.4 | 70.2 | 74.9 | 85.2 | 72.5 | 78.4 |
| ML-GCN (2019) | 82.4 | 84.4 | 71.4 | 77.4 | 85.8 | 74.5 | 79.8 |
| ML-GCN (ResNeXt50 with ImageNet) | 86.2 | 85.8 | 77.3 | 81.3 | 86.2 | 79.7 | 82.8 |
| A-GCN | 83.1 | 84.7 | 72.3 | 78.0 | 85.6 | 75.5 | 80.3 |
| KSSNet | 83.7 | 84.6 | 73.2 | 77.2 | 87.8 | 76.2 | 81.5 |
| SGTN (Ours\*\*) | 86.6 | 77.2 | **82.2** | 79.6 | 76.0 | **82.6** | 79.2 |
| **MGTN (Base)** | 86.91 | **89.38** | 74.46 | 81.25 | **90.91** | 76.27 | 82.95 |
| **MGTN (Final)** | **86.98** | 86.11 | 77.85 | **81.77** | 87.71 | 79.40 | **83.35** |

\*\* SGTN (Ours): https://github.com/ReML-AI/sgtn

### MGTN on COCO

```sh
python mgtn.py data/coco --image-size 448 --workers 8 --batch-size 32 --lr 0.03 --learning-rate-decay 0.1 --epoch_step 20 30 --embedding model/embedding/coco_glove_word2vec_80x300_ec.pkl --adj-strong-threshold 0.4 --adj-weak-threshold 0.2 --device_ids 0 1 2 3
```
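
For reference, the `--embedding` file is a pickled set of per-class word vectors (80 COCO classes x 300-dimensional GloVe vectors, judging by the file name). A small sketch of inspecting it, assuming it unpickles to an array-like object:

```python
import pickle
import numpy as np

# Hypothetical inspection; the exact stored type (NumPy array vs. torch tensor)
# is an assumption based on the file name, not confirmed by this commit.
with open('model/embedding/coco_glove_word2vec_80x300_ec.pkl', 'rb') as f:
    emb = pickle.load(f)
emb = np.asarray(emb)
print(emb.shape)  # expected (80, 300): one 300-d GloVe vector per COCO class
```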

### How to cite this work?
```
@inproceedings{Nguyen:AAAI:2021,
  author    = {Nguyen, Hoang D. and Vu, Xuan-Son and Le, Duc-Trong},
  title     = {Modular Graph Transformer Networks for Multi-Label Image Classification},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  series    = {AAAI '21},
  year      = {2021},
  publisher = {AAAI}
}
```

## References
This project is based on the following implementations:

- https://github.com/durandtibo/wildcat.pytorch
- https://github.com/tkipf/pygcn
- https://github.com/Megvii-Nanjing/ML_GCN/
- https://github.com/seongjunyun/Graph_Transformer_Networks

coco.py (+148)
import torch.utils.data as data
import json
import os
import subprocess
from PIL import Image
import numpy as np
import torch
import pickle
from util import *

urls = {'train_img': 'http://images.cocodataset.org/zips/train2014.zip',
        'val_img': 'http://images.cocodataset.org/zips/val2014.zip',
        'annotations': 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'}


def download_coco2014(root, phase):
    # Create the working directories (raw downloads in tmp/, processed data in data/).
    if not os.path.exists(root):
        os.makedirs(root)
    tmpdir = os.path.join(root, 'tmp/')
    data = os.path.join(root, 'data/')
    if not os.path.exists(data):
        os.makedirs(data)
    if not os.path.exists(tmpdir):
        os.makedirs(tmpdir)
    if phase == 'train':
        filename = 'train2014.zip'
    elif phase == 'val':
        filename = 'val2014.zip'
    cached_file = os.path.join(tmpdir, filename)
    img_data = os.path.join(data, filename.split('.')[0])
    # Download the image archive only if it is neither cached nor already extracted.
    if not os.path.exists(cached_file) and not os.path.exists(img_data):
        print('Downloading: "{}" to {}\n'.format(
            urls[phase + '_img'], cached_file))
        os.chdir(tmpdir)
        subprocess.call('wget ' + urls[phase + '_img'], shell=True)
        os.chdir(root)
    # extract file
    if not os.path.exists(img_data):
        print('[dataset] Extracting zip file {file} to {path}'.format(
            file=cached_file, path=data))
        command = 'unzip {} -d {}'.format(cached_file, data)
        os.system(command)
    print('[dataset] Done!')

    # train/val annotations
    cached_file = os.path.join(tmpdir, 'annotations_trainval2014.zip')
    annotations_data = os.path.join(data, 'annotations')
    # Download the annotation archive only if it is neither cached nor already extracted.
    if not os.path.exists(cached_file) and not os.path.exists(annotations_data):
        print('Downloading: "{}" to {}\n'.format(
            urls['annotations'], cached_file))
        os.chdir(tmpdir)
        # subprocess.call (unlike Popen) waits for wget to finish before extraction starts.
        subprocess.call('wget ' + urls['annotations'], shell=True)
        os.chdir(root)
    if not os.path.exists(annotations_data):
        print('[dataset] Extracting zip file {file} to {path}'.format(
            file=cached_file, path=data))
        command = 'unzip {} -d {}'.format(cached_file, data)
        os.system(command)
        print('[annotation] Done!')

    # Build the simplified {phase}_anno.json list of {file_name, labels} entries.
    anno = os.path.join(data, '{}_anno.json'.format(phase))
    img_id = {}
    annotations_id = {}
    if not os.path.exists(anno):
        annotations_file = json.load(
            open(os.path.join(annotations_data, 'instances_{}2014.json'.format(phase))))
        annotations = annotations_file['annotations']
        category = annotations_file['categories']
        category_id = {}
        for cat in category:
            category_id[cat['id']] = cat['name']
        cat2idx = category_to_idx(sorted(category_id.values()))
        images = annotations_file['images']
        # Collect the set of label indices present in each image.
        for annotation in annotations:
            if annotation['image_id'] not in annotations_id:
                annotations_id[annotation['image_id']] = set()
            annotations_id[annotation['image_id']].add(
                cat2idx[category_id[annotation['category_id']]])
        for img in images:
            if img['id'] not in annotations_id:
                continue
            if img['id'] not in img_id:
                img_id[img['id']] = {}
            img_id[img['id']]['file_name'] = img['file_name']
            img_id[img['id']]['labels'] = list(annotations_id[img['id']])
        anno_list = []
        for k, v in img_id.items():
            anno_list.append(v)
        json.dump(anno_list, open(anno, 'w'))
        if not os.path.exists(os.path.join(data, 'category.json')):
            json.dump(cat2idx, open(os.path.join(data, 'category.json'), 'w'))
        del img_id
        del anno_list
        del images
        del annotations_id
        del annotations
        del category
        del category_id
    print('[json] Done!')


def category_to_idx(category):
    # Map each (sorted) category name to a contiguous integer index.
    cat2idx = {}
    for cat in category:
        cat2idx[cat] = len(cat2idx)
    return cat2idx


class COCO2014(data.Dataset):
    def __init__(self, root, transform=None, phase='train', inp_name=None):
        self.root = root
        self.phase = phase
        self.img_list = []
        self.transform = transform
        download_coco2014(root, phase)
        self.get_anno()
        self.num_classes = len(self.cat2idx)

        # Load the pre-computed per-class word embeddings (the `inp` input).
        with open(inp_name, 'rb') as f:
            self.inp = pickle.load(f)
        self.inp_name = inp_name

    def get_anno(self):
        list_path = os.path.join(
            self.root, 'data', '{}_anno.json'.format(self.phase))
        self.img_list = json.load(open(list_path, 'r'))
        self.cat2idx = json.load(
            open(os.path.join(self.root, 'data', 'category.json'), 'r'))

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        item = self.img_list[index]
        return self.get(item)

    def get(self, item):
        filename = item['file_name']
        labels = sorted(item['labels'])
        img = Image.open(os.path.join(self.root, 'data',
                                      '{}2014'.format(self.phase), filename)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # Multi-label target: +1 for present categories, -1 for absent ones.
        target = np.zeros(self.num_classes, np.float32) - 1
        target[labels] = 1
        return (img, filename, self.inp), target
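
A minimal usage sketch of the `COCO2014` dataset above, assuming the `data/coco` root and embedding path from the README; the transform shown is illustrative and not the training pipeline defined in mgtn.py:

```python
import torch
import torchvision.transforms as transforms
from coco import COCO2014

# Illustrative preprocessing only; mgtn.py defines the actual training transforms.
transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
])

# Instantiating the dataset triggers download_coco2014 if data/coco is empty.
train_dataset = COCO2014('data/coco', transform=transform, phase='train',
                         inp_name='model/embedding/coco_glove_word2vec_80x300_ec.pkl')
loader = torch.utils.data.DataLoader(train_dataset, batch_size=32,
                                     shuffle=True, num_workers=8)

(img, filename, inp), target = train_dataset[0]
print(img.shape, target.shape)  # e.g. torch.Size([3, 448, 448]) and (80,)
```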
