Skip to content

Commit fd6890d

Browse files
committed
refine predict.py
1 parent 82d5a03 commit fd6890d

File tree

2 files changed

+149
-78
lines changed

2 files changed

+149
-78
lines changed

predict.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,56 +7,51 @@
77
from wpodnet import Predictor, load_wpodnet_from_checkpoint
88
from wpodnet.stream import ImageStreamer
99

10-
if __name__ == '__main__':
10+
if __name__ == "__main__":
1111
parser = ArgumentParser()
12+
parser.add_argument("source", type=str, help="the path to the image")
1213
parser.add_argument(
13-
'source',
14-
type=str,
15-
help='the path to the image'
16-
)
17-
parser.add_argument(
18-
'-w', '--weight',
19-
type=str,
20-
required=True,
21-
help='the path to the model weight'
14+
"-w", "--weight", type=str, required=True, help="the path to the model weight"
2215
)
2316
parser.add_argument(
24-
'--scale',
17+
"--scale",
2518
type=float,
2619
default=1.0,
27-
help='adjust the scaling ratio. default to 1.0.'
20+
help="adjust the scaling ratio. default to 1.0.",
2821
)
2922
parser.add_argument(
30-
'--save-annotated',
23+
"--save-annotated",
3124
type=str,
32-
help='save the annotated image at the given folder'
25+
help="save the annotated image at the given folder",
3326
)
3427
parser.add_argument(
35-
'--save-warped',
36-
type=str,
37-
help='save the warped image at the given folder'
28+
"--save-warped", type=str, help="save the warped image at the given folder"
3829
)
3930
args = parser.parse_args()
4031

4132
if args.scale <= 0.0:
42-
raise ArgumentTypeError(message='scale must be greater than 0.0')
33+
raise ArgumentTypeError(message="scale must be greater than 0.0")
4334

4435
if args.save_annotated is not None:
4536
save_annotated = Path(args.save_annotated)
4637
if not save_annotated.is_dir():
47-
raise FileNotFoundError(errno.ENOTDIR, 'No such directory', args.save_annotated)
38+
raise FileNotFoundError(
39+
errno.ENOTDIR, "No such directory", args.save_annotated
40+
)
4841
else:
4942
save_annotated = None
5043

5144
if args.save_warped is not None:
5245
save_warped = Path(args.save_warped)
5346
if not save_warped.is_dir():
54-
raise FileNotFoundError(errno.ENOTDIR, 'No such directory', args.save_warped)
47+
raise FileNotFoundError(
48+
errno.ENOTDIR, "No such directory", args.save_warped
49+
)
5550
else:
5651
save_warped = None
5752

5853
# Prepare for the model
59-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
54+
device = "cuda" if torch.cuda.is_available() else "cpu"
6055
model = load_wpodnet_from_checkpoint(args.weight).to(device)
6156

6257
predictor = Predictor(model)
@@ -65,20 +60,22 @@
6560
for i, image in enumerate(streamer):
6661
prediction = predictor.predict(image, scaling_ratio=args.scale)
6762

68-
print(f'Prediction #{i}')
69-
print(' bounds', prediction.bounds.tolist())
70-
print(' confidence', prediction.confidence)
63+
print(f"Prediction #{i}")
64+
print(" bounds", prediction.bounds)
65+
print(" confidence", prediction.confidence)
7166

7267
if save_annotated:
7368
annotated_path = save_annotated / Path(image.filename).name
74-
annotated = prediction.annotate()
75-
annotated.save(annotated_path)
76-
print(f'Saved the annotated image at {annotated_path}')
69+
70+
canvas = image.copy()
71+
prediction.annotate(canvas, outline="red")
72+
canvas.save(annotated_path)
73+
print(f"Saved the annotated image at {annotated_path}")
7774

7875
if save_warped:
7976
warped_path = save_warped / Path(image.filename).name
80-
warped = prediction.warp()
77+
warped = prediction.warp(image)
8178
warped.save(warped_path)
82-
print(f'Saved the warped image at {warped_path}')
79+
print(f"Saved the warped image at {warped_path}")
8380

8481
print()

wpodnet/backend.py

Lines changed: 123 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import List, Tuple
1+
from dataclasses import dataclass
2+
from typing import List, Optional, Tuple
23

34
import numpy as np
45
import torch
@@ -8,66 +9,116 @@
89
from .model import WPODNet
910

1011

12+
@dataclass
1113
class Prediction:
12-
def __init__(self, image: Image.Image, bounds: np.ndarray, confidence: float):
13-
self.image = image
14-
self.bounds = bounds
15-
self.confidence = confidence
16-
17-
def _get_perspective_coeffs(self, width: int, height: int) -> List[float]:
18-
# Get the perspective matrix
19-
src_points = self.bounds.tolist()
20-
dst_points = [[0, 0], [width, 0], [width, height], [0, height]]
21-
return _get_perspective_coeffs(src_points, dst_points)
22-
23-
def annotate(self, outline: str = 'red', width: int = 3) -> Image.Image:
24-
canvas = self.image.copy()
14+
"""
15+
The prediction result from WPODNet.
16+
17+
Attributes:
18+
bounds (List[Tuple[int, int]]): The bounding coordinates of the detected license plate. Must be a list of 4 points (x, y).
19+
confidence (float): The confidence score of the detection. Must be between 0.0 and 1.0.
20+
"""
21+
22+
bounds: List[Tuple[int, int]]
23+
confidence: float
24+
25+
def __post_init__(self):
26+
if len(self.bounds) != 4:
27+
raise ValueError(
28+
f"expected bounds to have 4 points, got {len(self.bounds)} points"
29+
)
30+
if self.confidence < 0 or self.confidence > 1:
31+
raise ValueError(
32+
f"confidence must be between 0.0 and 1.0, got {self.confidence}"
33+
)
34+
35+
def annotate(
36+
self,
37+
canvas: Image.Image,
38+
fill: Optional[str] = None,
39+
outline: Optional[str] = None,
40+
width: int = 1,
41+
) -> None:
42+
"""
43+
Annotates the image with the bounding polygon.
44+
45+
Args:
46+
canvas (PIL.Image.Image): The image to be annotated.
47+
fill (Optional[str]): The fill color for the polygon. Defaults to None.
48+
outline (Optional[str]): The outline color for the polygon. Defaults to None.
49+
width (int): The width of the outline. Defaults to 1.
50+
51+
Note:
52+
The arguments `fill`, `outline`, and `width` are passed to the `ImageDraw.Draw.polygon` method.
53+
See https://pillow.readthedocs.io/en/stable/reference/ImageDraw.html#PIL.ImageDraw.ImageDraw.polygon.
54+
"""
2555
drawer = ImageDraw.Draw(canvas)
26-
drawer.polygon(
27-
[(x, y) for x, y in self.bounds],
28-
outline=outline,
29-
width=width
56+
drawer.polygon(self.bounds, fill=fill, outline=outline, width=width)
57+
58+
def warp(self, canvas: Image.Image) -> Image.Image:
59+
"""
60+
Warps the image with perspective based on the bounding polygon.
61+
62+
Args:
63+
canvas (PIL.Image.Image): The image to be warped.
64+
65+
Returns:
66+
PIL.Image.Image: The warped image.
67+
"""
68+
coeffs = _get_perspective_coeffs(
69+
startpoints=self.bounds,
70+
endpoints=[
71+
(0, 0),
72+
(canvas.width, 0),
73+
(canvas.width, canvas.height),
74+
(0, canvas.height),
75+
],
3076
)
31-
return canvas
77+
return canvas.transform(
78+
(canvas.width, canvas.height), Image.Transform.PERSPECTIVE, coeffs
79+
)
80+
3281

33-
def warp(self, width: int = 208, height: int = 60) -> Image.Image:
34-
# Get the perspective matrix
35-
coeffs = self._get_perspective_coeffs(width, height)
36-
warped = self.image.transform((width, height), Image.PERSPECTIVE, coeffs)
37-
return warped
82+
Q = np.array(
83+
[
84+
[-0.5, 0.5, 0.5, -0.5],
85+
[-0.5, -0.5, 0.5, 0.5],
86+
[1.0, 1.0, 1.0, 1.0],
87+
]
88+
)
3889

3990

4091
class Predictor:
41-
_q = np.array([
42-
[-.5, .5, .5, -.5],
43-
[-.5, -.5, .5, .5],
44-
[1., 1., 1., 1.]
45-
])
46-
_scaling_const = 7.75
47-
_stride = 16
48-
49-
def __init__(self, wpodnet: WPODNet):
92+
"""A wrapper class for WPODNet to make predictions."""
93+
94+
def __init__(self, wpodnet: WPODNet) -> None:
95+
"""
96+
Args:
97+
wpodnet (WPODNet): The WPODNet model to use for prediction.
98+
"""
5099
self.wpodnet = wpodnet
51100
self.wpodnet.eval()
52101

53-
def _resize_to_fixed_ratio(self, image: Image.Image, dim_min: int, dim_max: int) -> Image.Image:
102+
def _resize_to_fixed_ratio(
103+
self, image: Image.Image, dim_min: int, dim_max: int
104+
) -> Image.Image:
54105
h, w = image.height, image.width
55106

56107
wh_ratio = max(h, w) / min(h, w)
57108
side = int(wh_ratio * dim_min)
58-
bound_dim = min(side + side % self._stride, dim_max)
109+
bound_dim = min(side + side % self.wpodnet.stride, dim_max)
59110

60111
factor = bound_dim / max(h, w)
61112
reg_w, reg_h = int(w * factor), int(h * factor)
62113

63-
# Ensure the both width and height are the multiply of `self._stride`
64-
reg_w_mod = reg_w % self._stride
114+
# Ensure the both width and height are the multiply of `self.wpodnet.stride`
115+
reg_w_mod = reg_w % self.wpodnet.stride
65116
if reg_w_mod > 0:
66-
reg_w += self._stride - reg_w_mod
117+
reg_w += self.wpodnet.stride - reg_w_mod
67118

68-
reg_h_mod = reg_h % self._stride
119+
reg_h_mod = reg_h % self.wpodnet.stride
69120
if reg_h_mod > 0:
70-
reg_h += self._stride - reg_h_mod
121+
reg_h += self.wpodnet.stride - reg_h_mod
71122

72123
return image.resize((reg_w, reg_h))
73124

@@ -82,32 +133,56 @@ def _inference(self, image: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
82133
# Convert to squeezed numpy array
83134
# grid_w: The number of anchors in row
84135
# grid_h: The number of anchors in column
85-
probs = np.squeeze(probs.cpu().numpy())[0] # (grid_h, grid_w)
136+
probs = np.squeeze(probs.cpu().numpy())[0] # (grid_h, grid_w)
86137
affines = np.squeeze(affines.cpu().numpy()) # (6, grid_h, grid_w)
87138

88139
return probs, affines
89140

90141
def _get_max_anchor(self, probs: np.ndarray) -> Tuple[int, int]:
91142
return np.unravel_index(probs.argmax(), probs.shape)
92143

93-
def _get_bounds(self, affines: np.ndarray, anchor_y: int, anchor_x: int, scaling_ratio: float = 1.0) -> np.ndarray:
144+
def _get_bounds(
145+
self,
146+
affines: np.ndarray,
147+
anchor_y: int,
148+
anchor_x: int,
149+
scaling_ratio: float = 1.0,
150+
) -> np.ndarray:
94151
# Compute theta
95152
theta = affines[:, anchor_y, anchor_x]
96153
theta = theta.reshape((2, 3))
97154
theta[0, 0] = max(theta[0, 0], 0.0)
98155
theta[1, 1] = max(theta[1, 1], 0.0)
99156

100157
# Convert theta into the bounding polygon
101-
bounds = np.matmul(theta, self._q) * self._scaling_const * scaling_ratio
158+
bounds = np.matmul(theta, Q) * self.wpodnet.scale_factor * scaling_ratio
102159

103160
# Normalize the bounds
104161
_, grid_h, grid_w = affines.shape
105-
bounds[0] = (bounds[0] + anchor_x + .5) / grid_w
106-
bounds[1] = (bounds[1] + anchor_y + .5) / grid_h
162+
bounds[0] = (bounds[0] + anchor_x + 0.5) / grid_w
163+
bounds[1] = (bounds[1] + anchor_y + 0.5) / grid_h
107164

108165
return np.transpose(bounds)
109166

110-
def predict(self, image: Image.Image, scaling_ratio: float = 1.0, dim_min: int = 512, dim_max: int = 768) -> Prediction:
167+
def predict(
168+
self,
169+
image: Image.Image,
170+
scaling_ratio: float = 1.0,
171+
dim_min: int = 512,
172+
dim_max: int = 768,
173+
) -> Prediction:
174+
"""
175+
Detect license plate in the image.
176+
177+
Args:
178+
image (Image.Image): The image to be detected.
179+
scaling_ratio (float): The scaling ratio of the resulting bounding polygon. Default to 1.0.
180+
dim_min (int): The minimum dimension of the resized image. Default to 512
181+
dim_max (int): The maximum dimension of the resized image. Default to 768
182+
183+
Returns:
184+
Prediction: The prediction result with highest confidence.
185+
"""
111186
orig_h, orig_w = image.height, image.width
112187

113188
# Resize the image to fixed ratio
@@ -130,7 +205,6 @@ def predict(self, image: Image.Image, scaling_ratio: float = 1.0, dim_min: int =
130205
bounds[:, 1] *= orig_h
131206

132207
return Prediction(
133-
image=image,
134-
bounds=bounds.astype(np.int32),
135-
confidence=max_prob.item()
208+
bounds=[(x, y) for x, y in np.int32(bounds).tolist()],
209+
confidence=max_prob.item(),
136210
)

0 commit comments

Comments
 (0)