diff --git a/paddlex/inference/models/common/vlm/transformers/model_utils.py b/paddlex/inference/models/common/vlm/transformers/model_utils.py
index aaa7600543..8c533f789d 100644
--- a/paddlex/inference/models/common/vlm/transformers/model_utils.py
+++ b/paddlex/inference/models/common/vlm/transformers/model_utils.py
@@ -296,6 +296,10 @@ def _load_state_dict_into_model(
         warnings.filterwarnings("ignore", message=r".*paddle.to_tensor.*")
         if convert_from_hf:
             try:
+                # Convert bfloat16 to float32 if needed to resolve dtype mismatch
+                for key, value in state_dict.items():
+                    if hasattr(value, 'dtype') and value.dtype == paddle.bfloat16:
+                        state_dict[key] = value.astype(paddle.float32)
                 model_to_load.set_hf_state_dict(state_dict)
             except NotImplementedError:
                 pass
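
For reference, a minimal standalone sketch of what this hunk does: any bfloat16 tensors in the state dict are cast to float32 in place before the weights are handed to set_hf_state_dict. The state-dict keys and tensor shapes below are hypothetical, not taken from the patched file.

```python
import paddle

# Hypothetical state dict standing in for the one passed to _load_state_dict_into_model.
state_dict = {
    "linear.weight": paddle.ones([2, 2]).astype(paddle.bfloat16),
    "linear.bias": paddle.zeros([2], dtype=paddle.float32),
}

# Same loop as in the diff: cast any bfloat16 tensors to float32 in place.
for key, value in state_dict.items():
    if hasattr(value, "dtype") and value.dtype == paddle.bfloat16:
        state_dict[key] = value.astype(paddle.float32)

# Both entries should now report paddle.float32.
print({k: v.dtype for k, v in state_dict.items()})
```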