1717import numpy as np
1818
1919from ....modules .text_detection .model_list import MODELS
20+ from ....utils .device import TemporaryDeviceChanger
2021from ....utils .func_register import FuncRegister
2122from ...common .batch_sampler import ImageBatchSampler
2223from ...common .reader import ReadImage
24+ from ...utils .misc import is_bfloat16_available , is_float16_available
2325from ..base import BasePredictor
2426from ..common import ToBatch , ToCHWImage
2527from .processors import DBPostProcess , DetResizeForTest , NormalizeImage
@@ -54,6 +56,15 @@ def __init__(
5456 self .unclip_ratio = unclip_ratio
5557 self .input_shape = input_shape
5658 self .max_side_limit = max_side_limit
59+
60+ self .device = kwargs .get ("device" , None )
61+ if is_bfloat16_available (self .device ):
62+ self .dtype = "bfloat16"
63+ elif is_float16_available (self .device ):
64+ self .dtype = "float16"
65+ else :
66+ self .dtype = "float32"
67+
5768 self .pre_tfs , self .infer , self .post_op = self ._build ()
5869
5970 def _build_batch_sampler (self ):
@@ -80,16 +91,18 @@ def _build(self):
8091 if self .model_name == "PP-OCRv5_mobile_det" :
8192 from .modeling import PPOCRV5MobileDet
8293
83- infer = PPOCRV5MobileDet .from_pretrained (
84- self .model_dir , use_safetensors = True , convert_from_hf = True
85- )
94+ with TemporaryDeviceChanger (self .device ):
95+ infer = PPOCRV5MobileDet .from_pretrained (
96+ self .model_dir , use_safetensors = True , convert_from_hf = True
97+ )
8698 infer .eval ()
8799 elif self .model_name == "PP-OCRv5_server_det" :
88100 from .modeling import PPOCRV5ServerDet
89101
90- infer = PPOCRV5ServerDet .from_pretrained (
91- self .model_dir , use_safetensors = True , convert_from_hf = True
92- )
102+ with TemporaryDeviceChanger (self .device ):
103+ infer = PPOCRV5ServerDet .from_pretrained (
104+ self .model_dir , use_safetensors = True , convert_from_hf = True
105+ )
93106 infer .eval ()
94107 else :
95108 raise RuntimeError (
@@ -122,7 +135,8 @@ def process(
122135 batch_imgs = self .pre_tfs ["Normalize" ](imgs = batch_imgs )
123136 batch_imgs = self .pre_tfs ["ToCHW" ](imgs = batch_imgs )
124137 x = self .pre_tfs ["ToBatch" ](imgs = batch_imgs )
125- batch_preds = self .infer (x = x )
138+ with TemporaryDeviceChanger (self .device ):
139+ batch_preds = self .infer (x = x )
126140 polys , scores = self .post_op (
127141 batch_preds ,
128142 batch_shapes ,
0 commit comments