From e03697e46e935ecfad4e91306072f38399316a8a Mon Sep 17 00:00:00 2001 From: "Scott H. Hawley" Date: Tue, 21 Oct 2025 20:52:27 -0500 Subject: [PATCH 1/2] Added MPS option for device --- stable_audio_tools/interface/gradio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py index cc234166..fb76fe66 100644 --- a/stable_audio_tools/interface/gradio.py +++ b/stable_audio_tools/interface/gradio.py @@ -363,7 +363,7 @@ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pret else: model_config = None - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint": @@ -375,4 +375,4 @@ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pret elif model_type == "lm": ui = create_lm_ui(model_config) - return ui \ No newline at end of file + return ui From 486abafdfa58561234d6b447c5c6d725a81d8219 Mon Sep 17 00:00:00 2001 From: "Scott H. Hawley" Date: Tue, 21 Oct 2025 20:55:31 -0500 Subject: [PATCH 2/2] Update device selection for compatibility with MPS --- stable_audio_tools/models/conditioners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py index 64f6e3a8..96b9af21 100644 --- a/stable_audio_tools/models/conditioners.py +++ b/stable_audio_tools/models/conditioners.py @@ -226,7 +226,7 @@ def __init__(self, project_out: bool = False): super().__init__(512, output_dim, project_out=project_out) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') # Suppress logging from transformers previous_level = logging.root.manager.disable @@ -758,4 +758,4 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An else: raise ValueError(f"Unknown conditioner type: {conditioner_type}") - return MultiConditioner(conditioners, default_keys=default_keys, pre_encoded_keys=pre_encoded_keys) \ No newline at end of file + return MultiConditioner(conditioners, default_keys=default_keys, pre_encoded_keys=pre_encoded_keys)