diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2af4ad0314c3..e2bbce7b0ead 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1422,7 +1422,18 @@ def test_float16_inference(self, expected_max_diff=5e-2): def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): - if hasattr(module, "half"): + # Account for components with _keep_in_fp32_modules + if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None: + for name, param in module.named_parameters(): + if any( + module_to_keep_in_fp32 in name.split(".") + for module_to_keep_in_fp32 in module._keep_in_fp32_modules + ): + param.data = param.data.to(torch_device).to(torch.float32) + else: + param.data = param.data.to(torch_device).to(torch.float16) + + elif hasattr(module, "half"): components[name] = module.to(torch_device).half() pipe = self.pipeline_class(**components)