From 3298b777f6acd2e4b0ab37f4c8ca6809dc2cf3e6 Mon Sep 17 00:00:00 2001
From: rob-maron <132852777+rob-maron@users.noreply.github.com>
Date: Tue, 10 Mar 2026 21:48:51 -0400
Subject: [PATCH] cuda smoke test changes

---
 shared/client/src/state/init.rs     |  9 ++++
 shared/modeling/src/device_utils.rs | 69 +++++++++++++++++++++++++----
 2 files changed, 70 insertions(+), 8 deletions(-)

diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index ec5d38d69..67bf23bd6 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -118,6 +118,9 @@ pub enum InitRunError {
     #[error("Unsupported architecture: {0}")]
     UnsupportedArchitecture(String),
 
+    #[error("Device is not usable: {0:#}")]
+    DeviceNotUsable(anyhow::Error),
+
     #[cfg(feature = "python")]
     #[error("Python distributed error: {0}")]
     PythonDistributedError(#[from] psyche_modeling::PythonDistributedCausalLMError),
@@ -193,6 +196,12 @@ impl RunInitConfigAndIO {
             ));
         }
 
+        // Fail fast if the device is not usable
+        init_config
+            .device
+            .ensure_usable()
+            .map_err(InitRunError::DeviceNotUsable)?;
+
         let model::Model::LLM(llm) = state.model;
 
         let hub_read_token = init_config.hub_read_token.clone();
diff --git a/shared/modeling/src/device_utils.rs b/shared/modeling/src/device_utils.rs
index 71a073f90..adf777890 100644
--- a/shared/modeling/src/device_utils.rs
+++ b/shared/modeling/src/device_utils.rs
@@ -1,9 +1,8 @@
 use std::{fmt, str::FromStr};
 
+use anyhow::Context;
 use itertools::Itertools;
-use tch::{Device, utils::has_mps};
-#[cfg(test)]
-use tch::{Kind, Tensor};
+use tch::{utils::has_mps, Device, Kind, Tensor};
 use thiserror::Error;
 
 /// Get all available CUDA devices
@@ -99,6 +98,22 @@ impl Devices {
             Devices::Mps => has_mps(),
         }
     }
+
+    /// Ensures the device is usable
+    ///
+    /// Currently does nothing for CPU and MPS devices
+    pub fn ensure_usable(&self) -> anyhow::Result<()> {
+        match self {
+            Devices::Cpu | Devices::Mps => Ok(()),
+            Devices::Cuda(device_indices) => {
+                for &device_idx in device_indices {
+                    ensure_cuda_device_usable(device_idx)
+                        .with_context(|| format!("cuda:{device_idx} is not usable"))?;
+                }
+                Ok(())
+            }
+        }
+    }
 }
 
 /// Get all available devices, for debugging purposes
@@ -217,6 +232,46 @@ impl DevicePytorchStr for Device {
     }
 }
 
+/// Ensures that a CUDA device is usable
+///
+/// The causal SDPA with a large sequence length is designed to trigger cuDNN's
+/// runtime kernel compilation
+fn ensure_cuda_device_usable(device_idx: usize) -> anyhow::Result<()> {
+    let device = Device::Cuda(device_idx);
+
+    let batch: i64 = 2;
+    let heads: i64 = 32;
+    let seq_len: i64 = 4096;
+    let head_dim: i64 = 128;
+
+    let q = Tensor::f_randn([batch, heads, seq_len, head_dim], (Kind::BFloat16, device))?
+        .f_set_requires_grad(true)
+        .context("failed to set requires grad")?;
+    let k = Tensor::f_randn([batch, heads, seq_len, head_dim], (Kind::BFloat16, device))?
+        .f_set_requires_grad(true)
+        .context("failed to set requires grad")?;
+    let v = Tensor::f_randn([batch, heads, seq_len, head_dim], (Kind::BFloat16, device))?
+        .f_set_requires_grad(true)
+        .context("failed to set requires grad")?;
+
+    let output = Tensor::f_scaled_dot_product_attention::<Tensor>(
+        &q, &k, &v, None,  // attn_mask
+        0.0,   // dropout_p
+        true,  // is_causal
+        None,  // scale (None = default 1/sqrt(head_dim))
+        false, // enable_gqa
+    )
+    .context("failed to run SDPA operation")?;
+
+    let loss = output
+        .f_sum(Kind::BFloat16)
+        .context("failed to sum SDPA output")?;
+    loss.f_backward()
+        .context("failed to run backward pass on SDPA output")?;
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use tch::utils::has_cuda;
@@ -242,11 +297,9 @@ mod tests {
                     .unwrap(),
                 Devices::Cuda((0..tch::Cuda::device_count() as usize).collect())
             );
-            assert!(
-                format!("cuda:{}", tch::Cuda::device_count())
-                    .parse::<Devices>()
-                    .is_err()
-            );
+            assert!(format!("cuda:{}", tch::Cuda::device_count())
+                .parse::<Devices>()
+                .is_err());
         } else {
             assert!(matches!(
                 "cuda".parse::<Devices>(),