Skip to content

Commit da69b5f

Browse files
committed
Fix temporary device changer
1 parent bb2b7fc commit da69b5f

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

paddlex/inference/models/text_detection/predictor.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
import numpy as np
1818

1919
from ....modules.text_detection.model_list import MODELS
20+
from ....utils.device import TemporaryDeviceChanger
2021
from ....utils.func_register import FuncRegister
2122
from ...common.batch_sampler import ImageBatchSampler
2223
from ...common.reader import ReadImage
24+
from ...utils.misc import is_bfloat16_available, is_float16_available
2325
from ..base import BasePredictor
2426
from ..common import ToBatch, ToCHWImage
2527
from .processors import DBPostProcess, DetResizeForTest, NormalizeImage
@@ -54,6 +56,15 @@ def __init__(
5456
self.unclip_ratio = unclip_ratio
5557
self.input_shape = input_shape
5658
self.max_side_limit = max_side_limit
59+
60+
self.device = kwargs.get("device", None)
61+
if is_bfloat16_available(self.device):
62+
self.dtype = "bfloat16"
63+
elif is_float16_available(self.device):
64+
self.dtype = "float16"
65+
else:
66+
self.dtype = "float32"
67+
5768
self.pre_tfs, self.infer, self.post_op = self._build()
5869

5970
def _build_batch_sampler(self):
@@ -80,16 +91,18 @@ def _build(self):
8091
if self.model_name == "PP-OCRv5_mobile_det":
8192
from .modeling import PPOCRV5MobileDet
8293

83-
infer = PPOCRV5MobileDet.from_pretrained(
84-
self.model_dir, use_safetensors=True, convert_from_hf=True
85-
)
94+
with TemporaryDeviceChanger(self.device):
95+
infer = PPOCRV5MobileDet.from_pretrained(
96+
self.model_dir, use_safetensors=True, convert_from_hf=True
97+
)
8698
infer.eval()
8799
elif self.model_name == "PP-OCRv5_server_det":
88100
from .modeling import PPOCRV5ServerDet
89101

90-
infer = PPOCRV5ServerDet.from_pretrained(
91-
self.model_dir, use_safetensors=True, convert_from_hf=True
92-
)
102+
with TemporaryDeviceChanger(self.device):
103+
infer = PPOCRV5ServerDet.from_pretrained(
104+
self.model_dir, use_safetensors=True, convert_from_hf=True
105+
)
93106
infer.eval()
94107
else:
95108
raise RuntimeError(
@@ -122,7 +135,8 @@ def process(
122135
batch_imgs = self.pre_tfs["Normalize"](imgs=batch_imgs)
123136
batch_imgs = self.pre_tfs["ToCHW"](imgs=batch_imgs)
124137
x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
125-
batch_preds = self.infer(x=x)
138+
with TemporaryDeviceChanger(self.device):
139+
batch_preds = self.infer(x=x)
126140
polys, scores = self.post_op(
127141
batch_preds,
128142
batch_shapes,

0 commit comments

Comments
 (0)