-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathinference_visdrone.py
84 lines (65 loc) · 2.52 KB
/
inference_visdrone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Mask R-CNN - Inference on the VisDrone2018-DET test-challenge images
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt
import mrcnn.utils
import mrcnn.model as modellib
import mrcnn.visualize as visualize
from mrcnn.config import Config
# Root directory of the project. Every path below is resolved relative to the
# directory the script is launched from.
ROOT_DIR = os.getcwd()

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Local path to trained weights file
VISDRONE_MODEL_PATH = os.path.join(
    MODEL_DIR, "visdrone20180504T2326", "mask_rcnn_visdrone_0014.h5")

# Dataset root: a "data" directory that sits next to ROOT_DIR
DATASET_ROOT_PATH = os.path.abspath(os.path.join(ROOT_DIR, "../data"))

# Directory of images to run detection on
IMAGE_DIR = os.path.join(
    DATASET_ROOT_PATH, "VisDrone2018-DET-test-challenge", "images")

# Directory for detection outputs, placed beside IMAGE_DIR. Each path
# component is passed to os.path.join separately so a separator is always
# inserted; plain string concatenation previously glued the dataset folder
# name onto "data" and produced ".../dataVisDrone2018-DET-test-challenge".
OUTPUT_PATH = os.path.join(
    DATASET_ROOT_PATH, "VisDrone2018-DET-test-challenge", "outputs")

# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs(OUTPUT_PATH, exist_ok=True)
## Configurations
class InferenceConfig(Config):
    """Inference-time settings for the VisDrone model.

    Derives from the base Config class and overrides only the values
    assigned below; everything else keeps the base-class default.
    """

    NAME = "VisDrone"

    # The detection loop in this script passes a single image per
    # model.detect() call, hence one GPU with one image per GPU.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # 12 classes in total, matching the 12 entries of the class_names list
    # defined later in this file. Presumably the "2" stands for the 'ignored'
    # and 'others' entries and the "10" for the object categories in between.
    NUM_CLASSES = 2 + 10

    # NOTE(review): presumably detections scoring below this threshold are
    # dropped by the model - confirm against mrcnn.config.Config.
    DETECTION_MIN_CONFIDENCE = 0.9


config = InferenceConfig()
config.display()
# ## Create Model and Load Trained Weights
# Build the Mask R-CNN model object in inference mode, handing it the
# inference configuration and MODEL_DIR as its model directory.
model = modellib.MaskRCNN(
    mode="inference",
    config=config,
    model_dir=MODEL_DIR,
)

# Load the VisDrone weights trained by TaylorMei (by_name=True is forwarded
# unchanged to load_weights).
model.load_weights(VISDRONE_MODEL_PATH, by_name=True)
# VisDrone class names.
# The index of a class in the list is its ID. For example, to get the ID of
# the car class, use: class_names.index('car')
class_names = [
    'ignored', 'pedestrian', 'people', 'bicycle', 'car', 'van',
    'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor', 'others',
]
# ## Run Object Detection
# os.listdir() returns every directory entry in arbitrary order, so keep only
# image files (anything else would make imread fail) and sort them so that
# repeated runs process the images in the same order.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
imglist = sorted(
    name for name in os.listdir(IMAGE_DIR)
    if name.lower().endswith(_IMAGE_EXTENSIONS)
)
print(len(imglist))

for imgname in imglist:
    image = skimage.io.imread(os.path.join(IMAGE_DIR, imgname))

    # Run detection on a one-image batch; the single result is at index 0.
    results = model.detect([image], verbose=1)
    r = results[0]

    # Dump the full result first, then the boxes, class IDs and scores.
    print(r)
    print(r['rois'], r['class_ids'], r['scores'])

    # TODO: persist the detections under OUTPUT_PATH. An earlier
    # commented-out draft targeted OUTPUT_PATH/<imgname>.txt with
    # comma-separated fields (x, y, w, h, score, cls, truncation, occlusion),
    # but it was unusable as written: it opened the file with mode "rw",
    # which open() rejects, it read lines back instead of writing them, and
    # it appended to a `label` list that was never defined.