|
1 | 1 | import errno |
2 | 2 | from argparse import ArgumentParser, ArgumentTypeError |
3 | 3 | from pathlib import Path |
| 4 | +from typing import List, Union |
4 | 5 |
|
5 | 6 | import torch |
| 7 | +from PIL import Image, UnidentifiedImageError |
6 | 8 |
|
7 | 9 | from wpodnet import Predictor, load_wpodnet_from_checkpoint |
8 | | -from wpodnet.stream import ImageStreamer |
| 10 | + |
| 11 | + |
| 12 | +def list_image_paths(p: Union[str, Path]) -> List[Path]: |
| 13 | + """ |
| 14 | + List all images in a directory. |
| 15 | +
|
| 16 | + Args: |
| 17 | + path (Union[str, Path]): The path to the directory containing images. |
| 18 | +
|
| 19 | + Returns: |
| 20 | + Generator[Image.Image]: A generator of PIL Image objects. |
| 21 | + """ |
| 22 | + p = Path(p) |
| 23 | + if not p.is_dir(): |
| 24 | + raise FileNotFoundError(errno.ENOTDIR, "No such directory", args.save_annotated) |
| 25 | + |
| 26 | + image_paths: List[Path] = [] |
| 27 | + for f in p.glob("**/*"): |
| 28 | + try: |
| 29 | + with Image.open(f) as image: |
| 30 | + image.verify() |
| 31 | + image_paths.append(f) |
| 32 | + except UnidentifiedImageError: |
| 33 | + pass |
| 34 | + return image_paths |
| 35 | + |
9 | 36 |
|
10 | 37 | if __name__ == "__main__": |
11 | 38 | parser = ArgumentParser() |
|
56 | 83 |
|
57 | 84 | predictor = Predictor(model) |
58 | 85 |
|
59 | | - streamer = ImageStreamer(args.source) |
60 | | - for i, image in enumerate(streamer): |
| 86 | + source = Path(args.source) |
| 87 | + if source.is_file(): |
| 88 | + image_paths = [source] |
| 89 | + elif source.is_dir(): |
| 90 | + image_paths = list_image_paths(source) |
| 91 | + else: |
| 92 | + raise FileNotFoundError(errno.ENOENT, "No such file or directory", args.source) |
| 93 | + |
| 94 | + for i, image_path in enumerate(image_paths): |
| 95 | + image = Image.open(image_path) |
61 | 96 | prediction = predictor.predict(image, scaling_ratio=args.scale) |
62 | 97 |
|
63 | 98 | print(f"Prediction #{i}") |
|
0 commit comments