diff --git a/backend/ltx2_server.py b/backend/ltx2_server.py
index ad1308489..6c44ec8f7 100644
--- a/backend/ltx2_server.py
+++ b/backend/ltx2_server.py
@@ -42,12 +42,97 @@
 logging.basicConfig(level=logging.INFO, handlers=[console_handler])
 logger = logging.getLogger(__name__)
 
+# ============================================================
+# CUDA Fallback Handling (for non-CUDA PyTorch builds on MPS/CPU)
+# ============================================================
+
+def _setup_cuda_fallback() -> None:
+    """
+    Monkey-patch torch.cuda functions when PyTorch was not compiled with CUDA
+    support (i.e., torch.version.cuda is None, as on MPS-only or CPU-only
+    wheels).
+
+    The ltx-pipelines library calls torch.cuda.synchronize() unconditionally,
+    which raises "Torch not compiled with CUDA enabled" on non-CUDA builds.
+    This patch is intentionally limited to non-CUDA builds so that CUDA-capable
+    installations that happen to be running on CPU (e.g., driver temporarily
+    unavailable) still surface real CUDA misconfiguration errors instead of
+    silently no-oping them.
+    """
+    # Only patch when PyTorch has no CUDA support compiled in.
+    if torch.version.cuda is not None:
+        return
+
+    device_type = DEVICE.type
+    logger.info(f"Setting up CUDA fallback for non-CUDA PyTorch build (device: {device_type})")
+
+    # Safe no-op implementations for the CUDA functions ltx-pipelines calls
+    def safe_cuda_synchronize(device: object = None) -> None:
+        """No-op synchronize for non-CUDA devices; delegates to MPS when available."""
+        if device_type == "mps":
+            try:
+                torch.mps.synchronize()
+            except (RuntimeError, AttributeError) as exc:
+                logger.debug("MPS synchronize fallback failed: %s", exc)
+
+    def safe_cuda_empty_cache() -> None:
+        """No-op empty_cache for non-CUDA devices; delegates to MPS when available."""
+        if device_type == "mps":
+            try:
+                torch.mps.empty_cache()
+            except (RuntimeError, AttributeError) as exc:
+                logger.debug("MPS empty_cache fallback failed: %s", exc)
+
+    def safe_cuda_memory_reserved(device: object = None) -> int:
+        """Return 0 for memory reserved on non-CUDA devices."""
+        return 0
+
+    def safe_cuda_memory_allocated(device: object = None) -> int:
+        """Return 0 for memory allocated on non-CUDA devices."""
+        return 0
+
+    def safe_cuda_get_device_name(device: object = None) -> str:
+        """Return a descriptive device name for non-CUDA devices."""
+        if device_type == "mps" and hasattr(torch, "mps"):
+            return "Apple Silicon MPS"
+        return "CPU"
+
+    def safe_cuda_get_device_capability(device: object = None) -> tuple[int, int]:
+        """Return (0, 0) for non-CUDA devices."""
+        return (0, 0)
+
+    # Patch the torch.cuda module
+    if not hasattr(torch.cuda, "_ltx_original_synchronize"):
+        # Stash the original synchronize once so re-running the patch is idempotent
+        try:
+            torch.cuda._ltx_original_synchronize = torch.cuda.synchronize  # type: ignore[attr-defined]
+        except AttributeError:
+            pass
+
+    # Replace with the safe implementations
+    torch.cuda.synchronize = safe_cuda_synchronize  # type: ignore[assignment]
+    torch.cuda.empty_cache = safe_cuda_empty_cache  # type: ignore[assignment]
+    torch.cuda.memory_reserved = safe_cuda_memory_reserved  # type: ignore[assignment]
+    torch.cuda.memory_allocated = safe_cuda_memory_allocated  # type: ignore[assignment]
+    torch.cuda.get_device_name = safe_cuda_get_device_name  # type: ignore[assignment]
+    torch.cuda.get_device_capability = safe_cuda_get_device_capability  # type: ignore[assignment]
+
+    logger.info("CUDA fallback patch applied successfully")
+
+
 # ============================================================
 # SageAttention Integration
 # ============================================================
 use_sage_attention = os.environ.get("USE_SAGE_ATTENTION", "1") == "1"
 _sageattention_runtime_fallback_logged = False
 
+# Check for MPS - SageAttention does not support the MPS backend
+_is_mps_device = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
+if use_sage_attention and _is_mps_device:
+    logger.info("SageAttention disabled - MPS device detected (not supported by SageAttention)")
+    use_sage_attention = False
+
 if use_sage_attention:
     try:
         from sageattention import sageattn  # type: ignore[reportMissingImports]
@@ -107,12 +192,17 @@ def _get_device() -> torch.device:
         return torch.device("cuda")
     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
         return torch.device("mps")
+    # Fall back to CPU when no GPU is available
+    logger.warning("No CUDA or MPS device available, using CPU")
    return torch.device("cpu")
 
 
 DEVICE = _get_device()
 DTYPE = torch.bfloat16
 
+# Set up the CUDA fallback for non-CUDA PyTorch builds (MPS/CPU support)
+_setup_cuda_fallback()
+
 
 def _resolve_app_data_dir() -> Path:
     env_path = os.environ.get("LTX_APP_DATA_DIR")
     if not env_path:
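
Reviewer note: below is a minimal standalone sketch of the monkey-patching technique this diff applies, runnable without importing `backend/ltx2_server.py`. The names `install_cuda_fallback_demo` and `on_mps` are illustrative, not part of the diff; the sketch stubs only `synchronize` and `empty_cache`, while the real patch also covers the memory-introspection functions.

```python
# Standalone sketch (hypothetical names) of the torch.cuda fallback pattern.
# Mirrors _setup_cuda_fallback(): it only patches when torch.version.cuda is
# None, i.e. on CPU-only or MPS-only PyTorch wheels.
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def install_cuda_fallback_demo() -> None:
    """Illustrative reimplementation; not imported from ltx2_server."""
    if torch.version.cuda is not None:
        return  # real CUDA wheel: leave torch.cuda untouched

    on_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    def safe_synchronize(device: object = None) -> None:
        if on_mps:
            torch.mps.synchronize()  # delegate to the MPS queue

    def safe_empty_cache() -> None:
        if on_mps:
            torch.mps.empty_cache()

    torch.cuda.synchronize = safe_synchronize  # type: ignore[assignment]
    torch.cuda.empty_cache = safe_empty_cache  # type: ignore[assignment]
    logger.info("torch.cuda patched for a non-CUDA build")


install_cuda_fallback_demo()
if torch.version.cuda is None:
    torch.cuda.synchronize()  # previously raised; now a no-op / MPS delegate
    torch.cuda.empty_cache()
```

A pytest-style check on an MPS machine could simply assert that `torch.cuda.synchronize()` no longer raises after the patch; the diff's version additionally stubs `memory_reserved`, `memory_allocated`, `get_device_name`, and `get_device_capability` so ltx-pipelines' logging paths survive as well.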