Skip to content

Commit affbdf7

Browse files
committed
[Fix]: Add TemporaryDeviceChanger for text_detection/predictor
1 parent 3aaddce commit affbdf7

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

paddlex/inference/models/text_detection/predictor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ....utils.func_register import FuncRegister
2222
from ...common.batch_sampler import ImageBatchSampler
2323
from ...common.reader import ReadImage
24+
from ...utils.misc import is_bfloat16_available, is_float16_available
2425
from ..base import BasePredictor
2526
from ..common import ToBatch, ToCHWImage
2627
from .processors import DBPostProcess, DetResizeForTest, NormalizeImage
@@ -55,6 +56,15 @@ def __init__(
5556
self.unclip_ratio = unclip_ratio
5657
self.input_shape = input_shape
5758
self.max_side_limit = max_side_limit
59+
60+
self.device = kwargs.get("device", None)
61+
if is_bfloat16_available(self.device):
62+
self.dtype = "bfloat16"
63+
elif is_float16_available(self.device):
64+
self.dtype = "float16"
65+
else:
66+
self.dtype = "float32"
67+
5868
self.pre_tfs, self.infer, self.post_op = self._build()
5969

6070
def _build_batch_sampler(self):

0 commit comments

Comments
 (0)