diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py index 3330fd9395..4ba5e7663f 100644 --- a/xinference/model/image/stable_diffusion/core.py +++ b/xinference/model/image/stable_diffusion/core.py @@ -412,12 +412,17 @@ def _get_scheduler(model: Any, sampler_name: str): else: raise ValueError(f"Unknown sampler: {sampler_name}") - @staticmethod @contextlib.contextmanager - def _reset_when_done(model: Any, sampler_name: str): + def _reset_when_done(self, model: Any, sampler_name: str): assert model is not None + if self._model_spec is None: + is_flux_sampler = False + else: + is_flux_sampler = "FLUX" in self._model_spec.model_name scheduler = DiffusionModel._get_scheduler(model, sampler_name) - if scheduler: + # When the model is not flux + should_swap_scheduler = scheduler is not None and not is_flux_sampler + if should_swap_scheduler: default_scheduler = model.scheduler model.scheduler = scheduler try: