From bae2ee7cf6bfca7c44fa6644c6cd22a7c0d3cc01 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 13 Mar 2026 11:58:48 -0300 Subject: [PATCH 1/2] Reduce the precision of the tensors to share the model --- config/solana-test/light-config.toml | 15 ++++++++------- scripts/train-solana-test.sh | 5 +++-- shared/client/src/state/cooldown.rs | 6 +++++- shared/network/src/p2p_model_sharing.rs | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index eab015342..2b607eeaa 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -1,27 +1,28 @@ [config] -warmup_time = 30 +warmup_time = 120 cooldown_time = 30 -epoch_time = 60 -max_round_train_time = 15 +epoch_time = 240 +max_round_train_time = 30 round_witness_time = 5 min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 0 -global_batch_size_start = 8 -global_batch_size_end = 8 +global_batch_size_start = 16 +global_batch_size_end = 16 global_batch_size_warmup_tokens = 0 total_steps = 25000 waiting_for_members_extra_time = 3 [model.LLM] -architecture = "HfLlama" +architecture = "Torchtitan" data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 [model.LLM.checkpoint.Hub] -repo_id = "emozilla/llama2-20m-init" +# repo_id = "emozilla/llama2-20m-init" +repo_id = "NousResearch/Meta-Llama-3.1-8B" [model.LLM.data_location.Http] token_size_in_bytes = "TwoBytes" diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index 55699e591..4e12231b4 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -26,7 +26,8 @@ RUN_ID=${RUN_ID:-"test"} AUTHORIZER=${AUTHORIZER:-"11111111111111111111111111111111"} # presets for a DGX or an HGX -DP=${DP:-"8"} +# DP=${DP:-"8"} +DP="8" TP=${TP:-"1"} BATCH_SIZE=${BATCH_SIZE:-"1"} @@ -36,7 +37,7 @@ solana airdrop 10 "$(solana-keygen pubkey ${WALLET_FILE})" --url "${RPC}" || tru export RUST_LOG="info,psyche=debug" if [[ "$OTLP_METRICS_URL" == "" ]]; then - cargo run --release --bin psyche-solana-client -- \ + cargo run --release --features python --bin psyche-solana-client -- \ train \ --wallet-private-key-path ${WALLET_FILE} \ --rpc ${RPC} \ diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 713e8c0c2..e150b355f 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -151,7 +151,11 @@ impl CooldownStepMetadata { event!(cooldown::ModelSerializationStarted); let (variables, trainer) = tokio::task::spawn_blocking::<_, Result<_, CheckpointError>>(|| { - let variables = trainer.extract()?; + let variables: HashMap = trainer + .extract()? + .into_iter() + .map(|(name, tensor)| (name, tensor.to_kind(tch::Kind::BFloat16))) + .collect(); info!("Model extracted; {} parameters", variables.len()); Ok((variables, trainer)) }) diff --git a/shared/network/src/p2p_model_sharing.rs b/shared/network/src/p2p_model_sharing.rs index ffd85f8f8..dd702b1fd 100644 --- a/shared/network/src/p2p_model_sharing.rs +++ b/shared/network/src/p2p_model_sharing.rs @@ -8,7 +8,7 @@ use psyche_event_sourcing::event; use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry}; use std::io::{Cursor, Write}; use std::time::Duration; -use tch::Tensor; +use tch::{Kind, Tensor}; use thiserror::Error; use tokenizers::Tokenizer; use tokio::sync::{ From 3ffdced71abb00db74ca03181815425b3135088f Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 19 Mar 2026 10:15:43 -0300 Subject: [PATCH 2/2] Truncate bf16 for all different ranks --- config/solana-test/light-config.toml | 15 +++---- python/extension-impl/src/extension.rs | 12 +++++ python/python/psyche/sidecar/__main__.py | 8 ++++ scripts/train-solana-test.sh | 5 +-- shared/client/src/state/cooldown.rs | 4 +- .../src/python_distributed_trainer.rs | 24 ++++++++++ shared/modeling/src/trainer.rs | 44 +++++++++++++++++++ shared/network/src/p2p_model_sharing.rs | 2 +- 8 files changed, 101 insertions(+), 13 deletions(-) diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index 2b607eeaa..eab015342 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -1,28 +1,27 @@ [config] -warmup_time = 120 +warmup_time = 30 cooldown_time = 30 -epoch_time = 240 -max_round_train_time = 30 +epoch_time = 60 +max_round_train_time = 15 round_witness_time = 5 min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 0 -global_batch_size_start = 16 -global_batch_size_end = 16 +global_batch_size_start = 8 +global_batch_size_end = 8 global_batch_size_warmup_tokens = 0 total_steps = 25000 waiting_for_members_extra_time = 3 [model.LLM] -architecture = "Torchtitan" +architecture = "HfLlama" data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 [model.LLM.checkpoint.Hub] -# repo_id = "emozilla/llama2-20m-init" -repo_id = "NousResearch/Meta-Llama-3.1-8B" +repo_id = "emozilla/llama2-20m-init" [model.LLM.data_location.Http] token_size_in_bytes = "TwoBytes" diff --git a/python/extension-impl/src/extension.rs b/python/extension-impl/src/extension.rs index 8417598f1..cb8e2127d 100644 --- a/python/extension-impl/src/extension.rs +++ b/python/extension-impl/src/extension.rs @@ -239,6 +239,18 @@ impl Trainer { } Ok(()) } + + pub fn truncate_bf16(self_: PyRef<'_, Self>) -> PyResult<()> { + let trainer = self_.trainer.write().unwrap().take(); + if let Some(mut trainer) = trainer { + let trainer = self_.py().allow_threads(move || { + let _ = trainer.truncate_bf16(); + trainer + }); + *self_.trainer.write().unwrap() = Some(trainer); + } + Ok(()) + } } #[pymodule] diff --git a/python/python/psyche/sidecar/__main__.py b/python/python/psyche/sidecar/__main__.py index d20f59d1a..c17d80142 100644 --- a/python/python/psyche/sidecar/__main__.py +++ b/python/python/psyche/sidecar/__main__.py @@ -318,6 +318,14 @@ def barrier(): with torch.no_grad(): trainer.extract() + elif operation["operation"] == "truncate_bf16": + if trainer is None: + raise RuntimeError( + "Got truncate_bf16 operation without having created a trainer" + ) + + with torch.no_grad(): + trainer.truncate_bf16() elif operation["operation"] == "forward": with torch.no_grad(): forward = ForwardOperation(**operation) diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index 4e12231b4..55699e591 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -26,8 +26,7 @@ RUN_ID=${RUN_ID:-"test"} AUTHORIZER=${AUTHORIZER:-"11111111111111111111111111111111"} # presets for a DGX or an HGX -# DP=${DP:-"8"} -DP="8" +DP=${DP:-"8"} TP=${TP:-"1"} BATCH_SIZE=${BATCH_SIZE:-"1"} @@ -37,7 +36,7 @@ solana airdrop 10 "$(solana-keygen pubkey ${WALLET_FILE})" --url "${RPC}" || tru export RUST_LOG="info,psyche=debug" if [[ "$OTLP_METRICS_URL" == "" ]]; then - cargo run --release --features python --bin psyche-solana-client -- \ + cargo run --release --bin psyche-solana-client -- \ train \ --wallet-private-key-path ${WALLET_FILE} \ --rpc ${RPC} \ diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index e150b355f..919735945 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -131,7 +131,7 @@ impl CooldownStepMetadata { mut trainers: Vec, state: &Coordinator, ) -> Result { - let Some(mut trainer) = trainers.pop() else { + let Some(trainer) = trainers.pop() else { return Err(CooldownError::NoTrainers); }; @@ -151,6 +151,8 @@ impl CooldownStepMetadata { event!(cooldown::ModelSerializationStarted); let (variables, trainer) = tokio::task::spawn_blocking::<_, Result<_, CheckpointError>>(|| { + let mut trainer = trainer; + trainer.truncate_bf16()?; let variables: HashMap = trainer .extract()? .into_iter() diff --git a/shared/modeling/src/python_distributed_trainer.rs b/shared/modeling/src/python_distributed_trainer.rs index 74e06fe32..3fd518a33 100644 --- a/shared/modeling/src/python_distributed_trainer.rs +++ b/shared/modeling/src/python_distributed_trainer.rs @@ -339,6 +339,30 @@ impl PythonDistributedTrainer { Ok(result) } + pub fn truncate_bf16(&mut self) -> Result<(), TrainerThreadCommunicationError> { + let operation = serde_json::json!({ + "operation": "truncate_bf16", + }); + + let iteration = self.iteration.fetch_add(1, Ordering::Relaxed); + trace!( + "Sending truncate_bf16 operation to Python clients, iteration = {}", + iteration + ); + + self.comm + .set(&iteration.to_string(), &operation.to_string())?; + + // barrier to ensure everyone has seen the broadcast + let dummy = Tensor::zeros([], (Kind::Float, self.device)); + self.comm.all_reduce(&dummy, ReduceType::Sum)?; + + self.local.truncate_bf16()?; + trace!("Truncate bf16 operation complete on all Python clients"); + + Ok(()) + } + fn broadcast_distro_results(&self, distro_results: &[DistroResults]) -> PyResult<()> { let first = distro_results.first().unwrap(); let params = first.len(); diff --git a/shared/modeling/src/trainer.rs b/shared/modeling/src/trainer.rs index 936cb12ba..9b3cb7c9e 100644 --- a/shared/modeling/src/trainer.rs +++ b/shared/modeling/src/trainer.rs @@ -271,6 +271,7 @@ enum ParallelAssignment { loss_scale: Option, }, Extract, + TruncateBf16, } #[derive(Debug)] @@ -288,6 +289,7 @@ enum ParallelResult { Extract { variables: HashMap, }, + TruncateBf16, } #[derive(Debug)] @@ -357,6 +359,14 @@ impl Trainer { } } + pub fn truncate_bf16(&mut self) -> Result<(), TrainerThreadCommunicationError> { + match self { + Trainer::Local(local_trainer) => local_trainer.truncate_bf16(), + #[cfg(feature = "python")] + Trainer::PythonDistributed(python) => python.truncate_bf16(), + } + } + pub fn can_do_inference(&self) -> bool { match self { Trainer::Local(local_trainer) => local_trainer.can_do_inference(), @@ -705,6 +715,28 @@ impl LocalTrainer { Ok(extracted) } + pub fn truncate_bf16(&mut self) -> Result<(), TrainerThreadCommunicationError> { + self.barrier.reset(); + for (tx, _) in &self.models { + tx.send(ParallelAssignment::TruncateBf16) + .map_err(|_| TrainerThreadCommunicationError::SendCommand)?; + } + for (_, rx) in &self.models { + match rx + .recv() + .map_err(|_| TrainerThreadCommunicationError::RecvResult)? + { + ParallelResult::TruncateBf16 => {} + result => { + return Err(TrainerThreadCommunicationError::UnexpectedResult(format!( + "{result:?}" + ))); + } + } + } + Ok(()) + } + // todo: refactor args into a struct #[allow(clippy::too_many_arguments)] fn model_thread( @@ -1111,6 +1143,18 @@ impl LocalTrainer { } } } + Ok(ParallelAssignment::TruncateBf16) => { + let _no_grad = tch::no_grad_guard(); + for var in model.variables() { + let mut tensor = var.local_tensor(); + let original_kind = tensor.kind(); + let truncated = tensor.to_kind(Kind::BFloat16).to_kind(original_kind); + tensor.copy_(&truncated); + } + if submission.send(ParallelResult::TruncateBf16).is_err() { + return; + } + } Err(_) => { return; } diff --git a/shared/network/src/p2p_model_sharing.rs b/shared/network/src/p2p_model_sharing.rs index dd702b1fd..ffd85f8f8 100644 --- a/shared/network/src/p2p_model_sharing.rs +++ b/shared/network/src/p2p_model_sharing.rs @@ -8,7 +8,7 @@ use psyche_event_sourcing::event; use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry}; use std::io::{Cursor, Write}; use std::time::Duration; -use tch::{Kind, Tensor}; +use tch::Tensor; use thiserror::Error; use tokenizers::Tokenizer; use tokio::sync::{