backend/ltx2_server.py (90 additions, 0 deletions)
@@ -42,12 +42,97 @@
logging.basicConfig(level=logging.INFO, handlers=[console_handler])
logger = logging.getLogger(__name__)

# ============================================================
# CUDA Fallback Handling (for non-CUDA PyTorch builds on MPS/CPU)
# ============================================================

def _setup_cuda_fallback() -> None:
    """
    Monkey-patch torch.cuda functions when PyTorch was not compiled with CUDA
    support (i.e., torch.version.cuda is None, as on MPS-only or CPU-only
    wheels).

    The ltx-pipelines library calls torch.cuda.synchronize() unconditionally,
    which raises "Torch not compiled with CUDA enabled" on non-CUDA builds.
    This patch is intentionally limited to non-CUDA builds so that CUDA-capable
    installations that happen to be running on CPU (e.g., driver temporarily
    unavailable) still surface real CUDA misconfiguration errors instead of
    silently no-oping them.
    """
    # Only patch when PyTorch has no CUDA support compiled in.
    if torch.version.cuda is not None:
        return

    device_type = DEVICE.type
    logger.info(f"Setting up CUDA fallback for non-CUDA PyTorch build (device: {device_type})")

    # Create safe no-op implementations for CUDA functions
    def safe_cuda_synchronize(device: object = None) -> None:
        """No-op synchronize for non-CUDA devices; delegates to MPS when available."""
        if device_type == "mps":
            try:
                torch.mps.synchronize()
            except (RuntimeError, AttributeError) as exc:
                logger.debug("MPS synchronize fallback failed: %s", exc)

    def safe_cuda_empty_cache() -> None:
        """No-op empty_cache for non-CUDA devices; delegates to MPS when available."""
        if device_type == "mps":
            try:
                torch.mps.empty_cache()
            except (RuntimeError, AttributeError) as exc:
                logger.debug("MPS empty_cache fallback failed: %s", exc)

    def safe_cuda_memory_reserved(device: object = None) -> int:
        """Return 0 for memory reserved on non-CUDA devices."""
        return 0

    def safe_cuda_memory_allocated(device: object = None) -> int:
        """Return 0 for memory allocated on non-CUDA devices."""
        return 0

    def safe_cuda_get_device_name(device: object = None) -> str:
        """Return a human-readable device name for non-CUDA devices."""
        if device_type == "mps" and hasattr(torch, "mps"):
            return "Apple Silicon MPS"
        return "CPU"

    def safe_cuda_get_device_capability(device: object = None) -> tuple[int, int]:
        """Return (0, 0) for non-CUDA devices."""
        return (0, 0)

    # Patch torch.cuda module
    if not hasattr(torch.cuda, "_ltx_original_synchronize"):
        # Store the original function once, if it exists
        try:
            torch.cuda._ltx_original_synchronize = torch.cuda.synchronize  # type: ignore[attr-defined]
        except AttributeError:
            pass

    # Replace with safe implementations
    torch.cuda.synchronize = safe_cuda_synchronize  # type: ignore[assignment]
    torch.cuda.empty_cache = safe_cuda_empty_cache  # type: ignore[assignment]
    torch.cuda.memory_reserved = safe_cuda_memory_reserved  # type: ignore[assignment]
    torch.cuda.memory_allocated = safe_cuda_memory_allocated  # type: ignore[assignment]
    torch.cuda.get_device_name = safe_cuda_get_device_name  # type: ignore[assignment]
    torch.cuda.get_device_capability = safe_cuda_get_device_capability  # type: ignore[assignment]

    logger.info("CUDA fallback patch applied successfully")


# ============================================================
# SageAttention Integration
# ============================================================
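# USE_SAGE_ATTENTION defaults to "1" (enabled); any other value disables it,
# e.g. launching with USE_SAGE_ATTENTION=0.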
use_sage_attention = os.environ.get("USE_SAGE_ATTENTION", "1") == "1"
_sageattention_runtime_fallback_logged = False

# Check for MPS device - SageAttention doesn't support MPS
_is_mps_device = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

if use_sage_attention and _is_mps_device:
    logger.info("SageAttention disabled - MPS device detected (not supported by SageAttention)")
    use_sage_attention = False

if use_sage_attention:
    try:
        from sageattention import sageattn  # type: ignore[reportMissingImports]
@@ -107,12 +192,17 @@ def _get_device() -> torch.device:
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    # Fall back to CPU if no GPU is available
    logger.warning("No CUDA or MPS device available, using CPU")
    return torch.device("cpu")


DEVICE = _get_device()
DTYPE = torch.bfloat16

# Set up the CUDA fallback for non-CUDA PyTorch builds (MPS/CPU support).
# Must run after DEVICE is assigned, since the fallback reads DEVICE.type.
_setup_cuda_fallback()
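
# Hypothetical smoke test (illustration only, not part of the original change);
# enable with LTX_CUDA_FALLBACK_SMOKE_TEST=1 on an MPS-only or CPU-only wheel.
if os.environ.get("LTX_CUDA_FALLBACK_SMOKE_TEST") == "1" and torch.version.cuda is None:
    torch.cuda.synchronize()                   # safe no-op after the patch
    torch.cuda.empty_cache()                   # safe no-op after the patch
    assert torch.cuda.memory_allocated() == 0  # patched accounting reports zero
    logger.info("CUDA fallback active on %s", torch.cuda.get_device_name())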

def _resolve_app_data_dir() -> Path:
    env_path = os.environ.get("LTX_APP_DATA_DIR")
    if not env_path: