diff --git a/python/extension-impl/src/extension.rs b/python/extension-impl/src/extension.rs index 8417598f1..cb8e2127d 100644 --- a/python/extension-impl/src/extension.rs +++ b/python/extension-impl/src/extension.rs @@ -239,6 +239,18 @@ impl Trainer { } Ok(()) } + + pub fn truncate_bf16(self_: PyRef<'_, Self>) -> PyResult<()> { + let trainer = self_.trainer.write().unwrap().take(); + if let Some(mut trainer) = trainer { + let trainer = self_.py().allow_threads(move || { + let _ = trainer.truncate_bf16(); + trainer + }); + *self_.trainer.write().unwrap() = Some(trainer); + } + Ok(()) + } } #[pymodule] diff --git a/python/python/psyche/sidecar/__main__.py b/python/python/psyche/sidecar/__main__.py index d20f59d1a..c17d80142 100644 --- a/python/python/psyche/sidecar/__main__.py +++ b/python/python/psyche/sidecar/__main__.py @@ -318,6 +318,14 @@ def barrier(): with torch.no_grad(): trainer.extract() + elif operation["operation"] == "truncate_bf16": + if trainer is None: + raise RuntimeError( + "Got truncate_bf16 operation without having created a trainer" + ) + + with torch.no_grad(): + trainer.truncate_bf16() elif operation["operation"] == "forward": with torch.no_grad(): forward = ForwardOperation(**operation) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 713e8c0c2..919735945 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -131,7 +131,7 @@ impl CooldownStepMetadata { mut trainers: Vec, state: &Coordinator, ) -> Result { - let Some(mut trainer) = trainers.pop() else { + let Some(trainer) = trainers.pop() else { return Err(CooldownError::NoTrainers); }; @@ -151,7 +151,13 @@ impl CooldownStepMetadata { event!(cooldown::ModelSerializationStarted); let (variables, trainer) = tokio::task::spawn_blocking::<_, Result<_, CheckpointError>>(|| { - let variables = trainer.extract()?; + let mut trainer = trainer; + trainer.truncate_bf16()?; + let variables: HashMap = trainer + .extract()? + .into_iter() + .map(|(name, tensor)| (name, tensor.to_kind(tch::Kind::BFloat16))) + .collect(); info!("Model extracted; {} parameters", variables.len()); Ok((variables, trainer)) }) diff --git a/shared/modeling/src/python_distributed_trainer.rs b/shared/modeling/src/python_distributed_trainer.rs index 74e06fe32..3fd518a33 100644 --- a/shared/modeling/src/python_distributed_trainer.rs +++ b/shared/modeling/src/python_distributed_trainer.rs @@ -339,6 +339,30 @@ impl PythonDistributedTrainer { Ok(result) } + pub fn truncate_bf16(&mut self) -> Result<(), TrainerThreadCommunicationError> { + let operation = serde_json::json!({ + "operation": "truncate_bf16", + }); + + let iteration = self.iteration.fetch_add(1, Ordering::Relaxed); + trace!( + "Sending truncate_bf16 operation to Python clients, iteration = {}", + iteration + ); + + self.comm + .set(&iteration.to_string(), &operation.to_string())?; + + // barrier to ensure everyone has seen the broadcast + let dummy = Tensor::zeros([], (Kind::Float, self.device)); + self.comm.all_reduce(&dummy, ReduceType::Sum)?; + + self.local.truncate_bf16()?; + trace!("Truncate bf16 operation complete on all Python clients"); + + Ok(()) + } + fn broadcast_distro_results(&self, distro_results: &[DistroResults]) -> PyResult<()> { let first = distro_results.first().unwrap(); let params = first.len(); diff --git a/shared/modeling/src/trainer.rs b/shared/modeling/src/trainer.rs index 936cb12ba..9b3cb7c9e 100644 --- a/shared/modeling/src/trainer.rs +++ b/shared/modeling/src/trainer.rs @@ -271,6 +271,7 @@ enum ParallelAssignment { loss_scale: Option, }, Extract, + TruncateBf16, } #[derive(Debug)] @@ -288,6 +289,7 @@ enum ParallelResult { Extract { variables: HashMap, }, + TruncateBf16, } #[derive(Debug)] @@ -357,6 +359,14 @@ impl Trainer { } } + pub fn truncate_bf16(&mut self) -> Result<(), TrainerThreadCommunicationError> { + match self { + Trainer::Local(local_trainer) => local_trainer.truncate_bf16(), + #[cfg(feature = "python")] + Trainer::PythonDistributed(python) => python.truncate_bf16(), + } + } + pub fn can_do_inference(&self) -> bool { match self { Trainer::Local(local_trainer) => local_trainer.can_do_inference(), @@ -705,6 +715,28 @@ impl LocalTrainer { Ok(extracted) } + pub fn truncate_bf16(&mut self) -> Result<(), TrainerThreadCommunicationError> { + self.barrier.reset(); + for (tx, _) in &self.models { + tx.send(ParallelAssignment::TruncateBf16) + .map_err(|_| TrainerThreadCommunicationError::SendCommand)?; + } + for (_, rx) in &self.models { + match rx + .recv() + .map_err(|_| TrainerThreadCommunicationError::RecvResult)? + { + ParallelResult::TruncateBf16 => {} + result => { + return Err(TrainerThreadCommunicationError::UnexpectedResult(format!( + "{result:?}" + ))); + } + } + } + Ok(()) + } + // todo: refactor args into a struct #[allow(clippy::too_many_arguments)] fn model_thread( @@ -1111,6 +1143,18 @@ impl LocalTrainer { } } } + Ok(ParallelAssignment::TruncateBf16) => { + let _no_grad = tch::no_grad_guard(); + for var in model.variables() { + let mut tensor = var.local_tensor(); + let original_kind = tensor.kind(); + let truncated = tensor.to_kind(Kind::BFloat16).to_kind(original_kind); + tensor.copy_(&truncated); + } + if submission.send(ParallelResult::TruncateBf16).is_err() { + return; + } + } Err(_) => { return; }